Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
39a6d0e6
Commit
39a6d0e6
authored
Apr 09, 2023
by
Christina Floristean
Browse files
Merging in main branch
parents
d8ee9c5f
84659c93
Changes
101
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2298 additions
and
649 deletions
+2298
-649
openfold/model/evoformer.py
openfold/model/evoformer.py
+567
-133
openfold/model/heads.py
openfold/model/heads.py
+9
-1
openfold/model/model.py
openfold/model/model.py
+203
-99
openfold/model/msa.py
openfold/model/msa.py
+104
-38
openfold/model/outer_product_mean.py
openfold/model/outer_product_mean.py
+38
-8
openfold/model/pair_transition.py
openfold/model/pair_transition.py
+6
-6
openfold/model/primitives.py
openfold/model/primitives.py
+179
-61
openfold/model/structure_module.py
openfold/model/structure_module.py
+259
-107
openfold/model/template.py
openfold/model/template.py
+360
-57
openfold/model/triangular_attention.py
openfold/model/triangular_attention.py
+37
-17
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+331
-36
openfold/np/__init__.py
openfold/np/__init__.py
+0
-16
openfold/np/protein.py
openfold/np/protein.py
+158
-25
openfold/np/relax/__init__.py
openfold/np/relax/__init__.py
+0
-16
openfold/np/relax/amber_minimize.py
openfold/np/relax/amber_minimize.py
+25
-4
openfold/np/relax/cleanup.py
openfold/np/relax/cleanup.py
+8
-2
openfold/np/relax/relax.py
openfold/np/relax/relax.py
+3
-0
openfold/np/relax/utils.py
openfold/np/relax/utils.py
+8
-2
openfold/np/residue_constants.py
openfold/np/residue_constants.py
+3
-3
openfold/utils/__init__.py
openfold/utils/__init__.py
+0
-18
No files found.
openfold/model/evoformer.py
View file @
39a6d0e6
...
@@ -12,11 +12,11 @@
...
@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
math
import
math
import
sys
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Sequence
,
Optional
from
functools
import
partial
from
functools
import
partial
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.model.primitives
import
Linear
,
LayerNorm
...
@@ -29,6 +29,7 @@ from openfold.model.msa import (
...
@@ -29,6 +29,7 @@ from openfold.model.msa import (
from
openfold.model.outer_product_mean
import
OuterProductMean
from
openfold.model.outer_product_mean
import
OuterProductMean
from
openfold.model.pair_transition
import
PairTransition
from
openfold.model.pair_transition
import
PairTransition
from
openfold.model.triangular_attention
import
(
from
openfold.model.triangular_attention
import
(
TriangleAttention
,
TriangleAttentionStartingNode
,
TriangleAttentionStartingNode
,
TriangleAttentionEndingNode
,
TriangleAttentionEndingNode
,
)
)
...
@@ -37,7 +38,8 @@ from openfold.model.triangular_multiplicative_update import (
...
@@ -37,7 +38,8 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming
,
TriangleMultiplicationIncoming
,
)
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.tensor_utils
import
chunk_layer
from
openfold.utils.chunk_utils
import
chunk_layer
,
ChunkSizeTuner
from
openfold.utils.tensor_utils
import
add
class
MSATransition
(
nn
.
Module
):
class
MSATransition
(
nn
.
Module
):
...
@@ -66,6 +68,7 @@ class MSATransition(nn.Module):
...
@@ -66,6 +68,7 @@ class MSATransition(nn.Module):
self
.
linear_2
=
Linear
(
self
.
n
*
self
.
c_m
,
self
.
c_m
,
init
=
"final"
)
self
.
linear_2
=
Linear
(
self
.
n
*
self
.
c_m
,
self
.
c_m
,
init
=
"final"
)
def
_transition
(
self
,
m
,
mask
):
def
_transition
(
self
,
m
,
mask
):
m
=
self
.
layer_norm
(
m
)
m
=
self
.
linear_1
(
m
)
m
=
self
.
linear_1
(
m
)
m
=
self
.
relu
(
m
)
m
=
self
.
relu
(
m
)
m
=
self
.
linear_2
(
m
)
*
mask
m
=
self
.
linear_2
(
m
)
*
mask
...
@@ -107,8 +110,6 @@ class MSATransition(nn.Module):
...
@@ -107,8 +110,6 @@ class MSATransition(nn.Module):
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
m
=
self
.
layer_norm
(
m
)
if
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
m
=
self
.
_chunk
(
m
,
mask
,
chunk_size
)
m
=
self
.
_chunk
(
m
,
mask
,
chunk_size
)
else
:
else
:
...
@@ -140,13 +141,13 @@ class PairStack(nn.Module):
...
@@ -140,13 +141,13 @@ class PairStack(nn.Module):
c_hidden_mul
,
c_hidden_mul
,
)
)
self
.
tri_att_start
=
TriangleAttention
StartingNode
(
self
.
tri_att_start
=
TriangleAttention
(
c_z
,
c_z
,
c_hidden_pair_att
,
c_hidden_pair_att
,
no_heads_pair
,
no_heads_pair
,
inf
=
inf
,
inf
=
inf
,
)
)
self
.
tri_att_end
=
TriangleAttention
EndingNode
(
self
.
tri_att_end
=
TriangleAttention
(
c_z
,
c_z
,
c_hidden_pair_att
,
c_hidden_pair_att
,
no_heads_pair
,
no_heads_pair
,
...
@@ -159,32 +160,109 @@ class PairStack(nn.Module):
...
@@ -159,32 +160,109 @@ class PairStack(nn.Module):
)
)
self
.
ps_dropout_row_layer
=
DropoutRowwise
(
pair_dropout
)
self
.
ps_dropout_row_layer
=
DropoutRowwise
(
pair_dropout
)
self
.
ps_dropout_col_layer
=
DropoutColumnwise
(
pair_dropout
)
def
forward
(
def
forward
(
self
,
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
z
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# should be disabled to better approximate the exact activations of
# the original.
# the original.
pair_trans_mask
=
pair_mask
if
_mask_trans
else
None
pair_trans_mask
=
pair_mask
if
_mask_trans
else
None
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
))
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
))
if
(
_attn_chunk_size
is
None
):
z
=
z
+
self
.
ps_dropout_row_layer
(
_attn_chunk_size
=
chunk_size
self
.
tri_att_start
(
z
,
mask
=
pair_mask
,
chunk_size
=
chunk_size
)
m
,
z
=
input_tensors
tmu_update
=
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
)
z
=
z
+
self
.
ps_dropout_col_layer
(
if
(
not
inplace_safe
):
self
.
tri_att_end
(
z
,
mask
=
pair_mask
,
chunk_size
=
chunk_size
)
z
=
z
+
self
.
ps_dropout_row_layer
(
tmu_update
)
else
:
z
=
tmu_update
del
tmu_update
tmu_update
=
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
)
z
=
z
+
self
.
pair_transition
(
if
(
not
inplace_safe
):
z
,
mask
=
pair_trans_mask
,
chunk_size
=
chunk_size
z
=
z
+
self
.
ps_dropout_row_layer
(
tmu_update
)
else
:
z
=
tmu_update
del
tmu_update
z
=
add
(
z
,
self
.
ps_dropout_row_layer
(
self
.
tri_att_start
(
z
,
mask
=
pair_mask
,
chunk_size
=
_attn_chunk_size
,
use_memory_efficient_kernel
=
False
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
inplace
=
inplace_safe
,
)
z
=
z
.
transpose
(
-
2
,
-
3
)
if
(
inplace_safe
):
input_tensors
[
1
]
=
z
.
contiguous
()
z
=
input_tensors
[
1
]
z
=
add
(
z
,
self
.
ps_dropout_row_layer
(
self
.
tri_att_end
(
z
,
mask
=
pair_mask
.
transpose
(
-
1
,
-
2
),
chunk_size
=
_attn_chunk_size
,
use_memory_efficient_kernel
=
False
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
inplace
=
inplace_safe
,
)
z
=
z
.
transpose
(
-
2
,
-
3
)
if
(
inplace_safe
):
input_tensors
[
1
]
=
z
.
contiguous
()
z
=
input_tensors
[
1
]
z
=
add
(
z
,
self
.
pair_transition
(
z
,
mask
=
pair_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
)
return
z
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: GPU
device
=
z
.
device
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
device
)
m
,
z
=
input_tensors
return
m
,
z
class
EvoformerBlock
(
nn
.
Module
):
class
EvoformerBlock
(
nn
.
Module
):
...
@@ -248,41 +326,134 @@ class EvoformerBlock(nn.Module):
...
@@ -248,41 +326,134 @@ class EvoformerBlock(nn.Module):
)
)
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
Optional
[
torch
.
Tensor
]
,
z
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
]
,
msa_mask
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
use_flash
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
if
(
_offload_inference
and
inplace_safe
):
input_tensors
=
_offloadable_inputs
del
_offloadable_inputs
else
:
input_tensors
=
[
m
,
z
]
m
,
z
=
input_tensors
if
self
.
opm_first
:
if
self
.
opm_first
:
z
=
z
+
self
.
outer_product_mean
(
if
(
_offload_inference
and
inplace_safe
):
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
# m: GPU, z: CPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
m
,
z
=
input_tensors
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
)
)
m
=
m
+
self
.
msa_dropout_layer
(
if
(
_offload_inference
and
inplace_safe
):
self
.
msa_att_row
(
m
,
z
=
z
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
)
# m: GPU, z: GPU
)
del
m
,
z
m
=
m
+
self
.
msa_att_col
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
)
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
opm
.
device
)
m
,
z
=
input_tensors
z
=
add
(
z
,
opm
,
inplace
=
inplace_safe
)
del
opm
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
m
,
z
=
z
,
mask
=
msa_mask
,
chunk_size
=
_attn_chunk_size
,
use_memory_efficient_kernel
=
False
,
use_lma
=
use_lma
,
)
),
inplace
=
inplace_safe
,
)
m
=
m
+
self
.
msa_transition
(
if
(
_offload_inference
and
inplace_safe
):
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
# m: GPU, z: CPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
torch
.
cuda
.
empty_cache
()
m
,
z
=
input_tensors
m
=
add
(
m
,
self
.
msa_att_col
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
),
inplace
=
inplace_safe
,
)
m
=
add
(
m
,
self
.
msa_transition
(
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
)
if
not
self
.
opm_first
:
if
not
self
.
opm_first
:
z
=
z
+
self
.
outer_product_mean
(
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
)
)
z
=
self
.
pair_stack
(
if
(
_offload_inference
and
inplace_safe
):
z
,
# m: CPU, z: GPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
cpu
()
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
opm
.
device
)
m
,
z
=
input_tensors
z
=
add
(
z
,
opm
,
inplace
=
inplace_safe
)
del
opm
elif
(
_offload_inference
and
inplace_safe
):
# m: CPU, z: GPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
device
=
input_tensors
[
0
].
device
input_tensors
[
0
]
=
input_tensors
[
0
].
cpu
()
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
device
)
m
,
z
=
input_tensors
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
m
,
z
=
self
.
pair_stack
(
input_tensors
,
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
)
)
return
m
,
z
return
m
,
z
...
@@ -358,63 +529,140 @@ class ExtraMSABlock(nn.Module):
...
@@ -358,63 +529,140 @@ class ExtraMSABlock(nn.Module):
)
)
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
Optional
[
torch
.
Tensor
]
,
z
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
]
,
msa_mask
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
_chunk_logits
:
Optional
[
int
]
=
1024
,
use_lma
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
inplace_safe
:
bool
=
False
,
def
add
(
m1
,
m2
):
_mask_trans
:
bool
=
True
,
# The first operation in a checkpoint can't be in-place, but it's
_attn_chunk_size
:
Optional
[
int
]
=
None
,
# nice to have in-place addition during inference. Thus...
_offload_inference
:
bool
=
False
,
if
(
torch
.
is_grad_enabled
()):
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
m1
=
m1
+
m2
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
else
:
if
(
_attn_chunk_size
is
None
):
m1
+=
m2
_attn_chunk_size
=
chunk_size
if
(
_offload_inference
and
inplace_safe
):
input_tensors
=
_offloadable_inputs
del
_offloadable_inputs
else
:
input_tensors
=
[
m
,
z
]
return
m1
m
,
z
=
input_tensors
if
self
.
opm_first
:
if
self
.
opm_first
:
z
=
z
+
self
.
outer_product_mean
(
if
(
_offload_inference
and
inplace_safe
):
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
# m: GPU, z: CPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
m
,
z
=
input_tensors
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
)
)
m
=
add
(
m
,
self
.
msa_dropout_layer
(
if
(
_offload_inference
and
inplace_safe
):
self
.
msa_att_row
(
# m: GPU, z: GPU
m
.
clone
()
if
torch
.
is_grad_enabled
()
else
m
,
del
m
,
z
z
=
z
.
clone
()
if
torch
.
is_grad_enabled
()
else
z
,
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
mask
=
msa_mask
,
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
opm
.
device
)
chunk_size
=
chunk_size
,
m
,
z
=
input_tensors
use_memory_efficient_kernel
=
not
_chunk_logits
,
_chunk_logits
=
_chunk_logits
if
torch
.
is_grad_enabled
()
else
None
,
z
=
add
(
z
,
opm
,
inplace
=
inplace_safe
)
_checkpoint_chunks
=
del
opm
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
m
.
clone
()
if
torch
.
is_grad_enabled
()
else
m
,
z
=
z
.
clone
()
if
torch
.
is_grad_enabled
()
else
z
,
mask
=
msa_mask
,
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
,
use_memory_efficient_kernel
=
not
use_lma
and
m
.
is_cuda
,
_checkpoint_chunks
=
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
)
),
inplace
=
inplace_safe
,
)
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
def
fn
(
input_tensors
):
m
,
z
=
input_tensors
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: CPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
torch
.
cuda
.
empty_cache
()
m
,
z
=
input_tensors
m
=
add
(
m
,
self
.
msa_att_col
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
),
inplace
=
inplace_safe
,
)
m
=
add
(
m
,
self
.
msa_transition
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
)
))
def
fn
(
m
,
z
):
m
=
add
(
m
,
self
.
msa_att_col
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
))
m
=
add
(
m
,
self
.
msa_transition
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
))
if
not
self
.
opm_first
:
if
not
self
.
opm_first
:
z
=
z
+
self
.
outer_product_mean
(
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
inplace_safe
)
)
z
=
self
.
pair_stack
(
if
(
_offload_inference
and
inplace_safe
):
z
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
# m: CPU, z: GPU
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
cpu
()
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
opm
.
device
)
m
,
z
=
input_tensors
z
=
add
(
z
,
opm
,
inplace
=
inplace_safe
)
del
opm
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
m
,
z
=
self
.
pair_stack
(
input_tensors
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
)
)
return
m
,
z
return
m
,
z
if
(
torch
.
is_grad_enabled
()
and
self
.
ckpt
):
if
(
torch
.
is_grad_enabled
()
and
self
.
ckpt
):
checkpoint_fn
=
get_checkpoint_fn
()
checkpoint_fn
=
get_checkpoint_fn
()
m
,
z
=
checkpoint_fn
(
fn
,
m
,
z
)
m
,
z
=
checkpoint_fn
(
fn
,
input_tensors
)
else
:
else
:
m
,
z
=
fn
(
m
,
z
)
m
,
z
=
fn
(
input_tensors
)
return
m
,
z
return
m
,
z
...
@@ -446,6 +694,7 @@ class EvoformerStack(nn.Module):
...
@@ -446,6 +694,7 @@ class EvoformerStack(nn.Module):
inf
:
float
,
inf
:
float
,
eps
:
float
,
eps
:
float
,
clear_cache_between_blocks
:
bool
=
False
,
clear_cache_between_blocks
:
bool
=
False
,
tune_chunk_size
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
):
):
"""
"""
...
@@ -482,6 +731,8 @@ class EvoformerStack(nn.Module):
...
@@ -482,6 +731,8 @@ class EvoformerStack(nn.Module):
clear_cache_between_blocks:
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
stack. Slows down each block but can reduce fragmentation
tune_chunk_size:
Whether to dynamically tune the module's chunk size
"""
"""
super
(
EvoformerStack
,
self
).
__init__
()
super
(
EvoformerStack
,
self
).
__init__
()
...
@@ -511,14 +762,114 @@ class EvoformerStack(nn.Module):
...
@@ -511,14 +762,114 @@ class EvoformerStack(nn.Module):
self
.
linear
=
Linear
(
c_m
,
c_s
)
self
.
linear
=
Linear
(
c_m
,
c_s
)
self
.
tune_chunk_size
=
tune_chunk_size
self
.
chunk_size_tuner
=
None
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
_prep_blocks
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
,
use_flash
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
,
_mask_trans
:
bool
,
):
blocks
=
[
partial
(
b
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
if
(
self
.
clear_cache_between_blocks
):
def
block_with_cache_clear
(
block
,
*
args
,
**
kwargs
):
torch
.
cuda
.
empty_cache
()
return
block
(
*
args
,
**
kwargs
)
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
# We don't want to write in-place during chunk tuning runs
args
=
(
m
.
clone
(),
z
.
clone
(),),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
partial
(
b
,
chunk_size
=
tuned_chunk_size
,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
4
),
)
for
b
in
blocks
]
return
blocks
def
_forward_offload
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
use_flash
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
(
not
(
self
.
training
or
torch
.
is_grad_enabled
()))
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# this function
m
=
input_tensors
[
0
],
z
=
input_tensors
[
1
],
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
True
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
m
,
z
=
b
(
None
,
None
,
_offload_inference
=
True
,
_offloadable_inputs
=
input_tensors
,
)
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
del
m
,
z
m
,
z
=
input_tensors
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
return
m
,
z
,
s
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
use_flash
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]
:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Args:
Args:
m:
m:
...
@@ -529,6 +880,13 @@ class EvoformerStack(nn.Module):
...
@@ -529,6 +880,13 @@ class EvoformerStack(nn.Module):
[*, N_seq, N_res] MSA mask
[*, N_seq, N_res] MSA mask
pair_mask:
pair_mask:
[*, N_res, N_res] pair mask
[*, N_res, N_res] pair mask
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
use_flash:
Whether to use FlashAttention where possible. Mutually
exclusive with use_lma.
Returns:
Returns:
m:
m:
[*, N_seq, N_res, C_m] MSA embedding
[*, N_seq, N_res, C_m] MSA embedding
...
@@ -536,33 +894,31 @@ class EvoformerStack(nn.Module):
...
@@ -536,33 +894,31 @@ class EvoformerStack(nn.Module):
[*, N_res, N_res, C_z] pair embedding
[*, N_res, N_res, C_z] pair embedding
s:
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
"""
blocks
=
[
blocks
=
self
.
_prep_blocks
(
partial
(
m
=
m
,
b
,
z
=
z
,
msa_mask
=
msa_mask
,
chunk_size
=
chunk_size
,
pair_mask
=
pair_mask
,
use_lma
=
use_lma
,
chunk_size
=
chunk_size
,
use_flash
=
use_flash
,
_mask_trans
=
_mask_trans
,
msa_mask
=
msa_mask
,
)
pair_mask
=
pair_mask
,
for
b
in
self
.
blocks
inplace_safe
=
inplace_safe
,
]
_mask_trans
=
_mask_trans
,
)
if
(
self
.
clear_cache_between_blocks
):
def
block_with_cache_clear
(
block
,
*
args
):
torch
.
cuda
.
empty_cache
()
return
block
(
*
args
)
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
(
not
torch
.
is_grad_enabled
()):
blocks_per_ckpt
=
None
m
,
z
=
checkpoint_blocks
(
m
,
z
=
checkpoint_blocks
(
blocks
,
blocks
,
args
=
(
m
,
z
),
args
=
(
m
,
z
),
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
blocks_per_ckpt
=
blocks_per_ckpt
,
)
)
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
return
m
,
z
,
s
return
m
,
z
,
s
...
@@ -570,7 +926,6 @@ class ExtraMSAStack(nn.Module):
...
@@ -570,7 +926,6 @@ class ExtraMSAStack(nn.Module):
"""
"""
Implements Algorithm 18.
Implements Algorithm 18.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_m
:
int
,
c_m
:
int
,
c_z
:
int
,
c_z
:
int
,
...
@@ -589,14 +944,13 @@ class ExtraMSAStack(nn.Module):
...
@@ -589,14 +944,13 @@ class ExtraMSAStack(nn.Module):
eps
:
float
,
eps
:
float
,
ckpt
:
bool
,
ckpt
:
bool
,
clear_cache_between_blocks
:
bool
=
False
,
clear_cache_between_blocks
:
bool
=
False
,
chunk_
msa_attn
:
bool
=
False
,
tune_
chunk_
size
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
):
):
super
(
ExtraMSAStack
,
self
).
__init__
()
super
(
ExtraMSAStack
,
self
).
__init__
()
self
.
ckpt
=
ckpt
self
.
ckpt
=
ckpt
self
.
clear_cache_between_blocks
=
clear_cache_between_blocks
self
.
clear_cache_between_blocks
=
clear_cache_between_blocks
self
.
chunk_msa_attn
=
chunk_msa_attn
self
.
blocks
=
nn
.
ModuleList
()
self
.
blocks
=
nn
.
ModuleList
()
for
_
in
range
(
no_blocks
):
for
_
in
range
(
no_blocks
):
block
=
ExtraMSABlock
(
block
=
ExtraMSABlock
(
...
@@ -614,16 +968,107 @@ class ExtraMSAStack(nn.Module):
...
@@ -614,16 +968,107 @@ class ExtraMSAStack(nn.Module):
opm_first
=
opm_first
,
opm_first
=
opm_first
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
ckpt
=
ckpt
if
chunk_msa_attn
else
False
,
ckpt
=
False
,
)
)
self
.
blocks
.
append
(
block
)
self
.
blocks
.
append
(
block
)
self
.
tune_chunk_size
=
tune_chunk_size
self
.
chunk_size_tuner
=
None
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
_prep_blocks
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
,
_mask_trans
:
bool
,
):
blocks
=
[
partial
(
b
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
def
clear_cache
(
b
,
*
args
,
**
kwargs
):
torch
.
cuda
.
empty_cache
()
return
b
(
*
args
,
**
kwargs
)
if
(
self
.
clear_cache_between_blocks
):
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
# Tensors cloned to avoid getting written to in-place
# A corollary is that chunk size tuning should be disabled for
# large N, when z gets really big
args
=
(
m
.
clone
(),
z
.
clone
(),),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
partial
(
b
,
chunk_size
=
tuned_chunk_size
,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
4
),
)
for
b
in
blocks
]
return
blocks
def
_forward_offload
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
chunk_size
:
int
,
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_mask_trans
:
bool
=
True
,
)
->
torch
.
Tensor
:
assert
(
not
(
self
.
training
or
torch
.
is_grad_enabled
()))
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# this function
m
=
input_tensors
[
0
],
z
=
input_tensors
[
1
],
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
True
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
m
,
z
=
b
(
None
,
None
,
_offload_inference
=
True
,
_offloadable_inputs
=
input_tensors
,
)
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
del
m
,
z
return
input_tensors
[
1
]
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
chunk_size
:
int
,
chunk_size
:
int
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
Non
e
,
use_lma
:
bool
=
Fals
e
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
Non
e
,
inplace_safe
:
bool
=
Fals
e
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
...
@@ -632,6 +1077,8 @@ class ExtraMSAStack(nn.Module):
...
@@ -632,6 +1077,8 @@ class ExtraMSAStack(nn.Module):
[*, N_extra, N_res, C_m] extra MSA embedding
[*, N_extra, N_res, C_m] extra MSA embedding
z:
z:
[*, N_res, N_res, C_z] pair embedding
[*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules
use_lma: Whether to use low-memory attention during inference
msa_mask:
msa_mask:
Optional [*, N_extra, N_res] MSA mask
Optional [*, N_extra, N_res] MSA mask
pair_mask:
pair_mask:
...
@@ -639,35 +1086,22 @@ class ExtraMSAStack(nn.Module):
...
@@ -639,35 +1086,22 @@ class ExtraMSAStack(nn.Module):
Returns:
Returns:
[*, N_res, N_res, C_z] pair update
[*, N_res, N_res, C_z] pair update
"""
"""
if
(
not
self
.
chunk_msa_attn
):
checkpoint_fn
=
get_checkpoint_fn
()
checkpoint_fn
=
get_checkpoint_fn
()
blocks
=
self
.
_prep_blocks
(
blocks
=
[
m
=
m
,
partial
(
z
=
z
,
b
,
chunk_size
=
chunk_size
,
msa_mask
=
msa_mask
,
use_lma
=
use_lma
,
pair_mask
=
pair_mask
,
msa_mask
=
msa_mask
,
chunk_size
=
chunk_size
,
pair_mask
=
pair_mask
,
_chunk_logits
=
None
inplace_safe
=
inplace_safe
,
)
for
b
in
self
.
blocks
_mask_trans
=
_mask_trans
,
]
)
def
clear_cache
(
b
,
*
args
):
torch
.
cuda
.
empty_cache
()
return
b
(
*
args
)
if
(
self
.
clear_cache_between_blocks
):
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
for
b
in
blocks
:
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
m
,
z
=
checkpoint_fn
(
b
,
*
(
m
,
z
))
else
:
m
,
z
=
b
(
m
,
z
)
else
:
for
b
in
self
.
blocks
:
m
,
z
=
b
(
m
,
z
,
msa_mask
,
pair_mask
,
chunk_size
=
chunk_size
)
if
(
self
.
clear_cache_between_blocks
):
for
b
in
blocks
:
torch
.
cuda
.
empty_cache
()
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
m
,
z
=
checkpoint_fn
(
b
,
m
,
z
)
else
:
m
,
z
=
b
(
m
,
z
)
return
z
return
z
openfold/model/heads.py
View file @
39a6d0e6
...
@@ -22,6 +22,7 @@ from openfold.utils.loss import (
...
@@ -22,6 +22,7 @@ from openfold.utils.loss import (
compute_tm
,
compute_tm
,
compute_predicted_aligned_error
,
compute_predicted_aligned_error
,
)
)
from
openfold.utils.precision_utils
import
is_fp16_enabled
class
AuxiliaryHeads
(
nn
.
Module
):
class
AuxiliaryHeads
(
nn
.
Module
):
...
@@ -137,7 +138,7 @@ class DistogramHead(nn.Module):
...
@@ -137,7 +138,7 @@ class DistogramHead(nn.Module):
self
.
linear
=
Linear
(
self
.
c_z
,
self
.
no_bins
,
init
=
"final"
)
self
.
linear
=
Linear
(
self
.
c_z
,
self
.
no_bins
,
init
=
"final"
)
def
forward
(
self
,
z
):
# [*, N, N, C_z]
def
_
forward
(
self
,
z
):
# [*, N, N, C_z]
"""
"""
Args:
Args:
z:
z:
...
@@ -149,6 +150,13 @@ class DistogramHead(nn.Module):
...
@@ -149,6 +150,13 @@ class DistogramHead(nn.Module):
logits
=
self
.
linear
(
z
)
logits
=
self
.
linear
(
z
)
logits
=
logits
+
logits
.
transpose
(
-
2
,
-
3
)
logits
=
logits
+
logits
.
transpose
(
-
2
,
-
3
)
return
logits
return
logits
def
forward
(
self
,
z
):
if
(
is_fp16_enabled
()):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
return
self
.
_forward
(
z
.
float
())
else
:
return
self
.
_forward
(
z
)
class
TMScoreHead
(
nn
.
Module
):
class
TMScoreHead
(
nn
.
Module
):
...
...
openfold/model/model.py
View file @
39a6d0e6
...
@@ -12,8 +12,9 @@
...
@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
partial
from
functools
import
partial
import
weakref
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -34,12 +35,26 @@ from openfold.model.embedders import (
...
@@ -34,12 +35,26 @@ from openfold.model.embedders import (
)
)
from
openfold.model.evoformer
import
EvoformerStack
,
ExtraMSAStack
from
openfold.model.evoformer
import
EvoformerStack
,
ExtraMSAStack
from
openfold.model.heads
import
AuxiliaryHeads
from
openfold.model.heads
import
AuxiliaryHeads
import
openfold.np.residue_constants
as
residue_constants
from
openfold.model.structure_module
import
StructureModule
from
openfold.model.structure_module
import
StructureModule
from
openfold.model.template
import
(
TemplatePairStack
,
TemplatePointwiseAttention
,
embed_templates_average
,
embed_templates_offload
,
)
import
openfold.np.residue_constants
as
residue_constants
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
build_extra_msa_feat
,
build_template_angle_feat
,
build_template_pair_feat
,
atom14_to_atom37
,
)
from
openfold.utils.loss
import
(
from
openfold.utils.loss
import
(
compute_plddt
,
compute_plddt
,
)
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
add
,
dict_multimap
,
dict_multimap
,
tensor_tree_map
,
tensor_tree_map
,
)
)
...
@@ -61,55 +76,96 @@ class AlphaFold(nn.Module):
...
@@ -61,55 +76,96 @@ class AlphaFold(nn.Module):
super
(
AlphaFold
,
self
).
__init__
()
super
(
AlphaFold
,
self
).
__init__
()
self
.
globals
=
config
.
globals
self
.
globals
=
config
.
globals
config
=
config
.
model
self
.
config
=
config
.
model
template_config
=
config
.
template
self
.
template_config
=
self
.
config
.
template
extra_msa_config
=
config
.
extra_msa
self
.
extra_msa_config
=
self
.
config
.
extra_msa
# Main trunk + structure module
# Main trunk + structure module
if
(
self
.
globals
.
is_multimer
):
if
(
self
.
globals
.
is_multimer
):
self
.
input_embedder
=
InputEmbedderMultimer
(
self
.
input_embedder
=
InputEmbedderMultimer
(
**
config
[
"input_embedder"
],
**
self
.
config
[
"input_embedder"
],
)
)
else
:
else
:
self
.
input_embedder
=
InputEmbedder
(
self
.
input_embedder
=
InputEmbedder
(
**
config
[
"input_embedder"
],
**
self
.
config
[
"input_embedder"
],
)
)
self
.
recycling_embedder
=
RecyclingEmbedder
(
self
.
recycling_embedder
=
RecyclingEmbedder
(
**
config
[
"recycling_embedder"
],
**
self
.
config
[
"recycling_embedder"
],
)
)
if
(
self
.
globals
.
is_multimer
):
if
(
self
.
template_config
.
enabled
):
self
.
template_embedder
=
TemplateEmbedderMultimer
(
if
(
self
.
globals
.
is_multimer
):
template_config
,
self
.
template_embedder
=
TemplateEmbedderMultimer
(
self
.
template_config
,
)
else
:
self
.
template_embedder
=
TemplateEmbedder
(
self
.
template_config
,
)
if
(
self
.
extra_msa_config
.
enabled
):
self
.
extra_msa_embedder
=
ExtraMSAEmbedder
(
**
self
.
extra_msa_config
[
"extra_msa_embedder"
],
)
)
else
:
self
.
extra_msa_stack
=
ExtraMSAStack
(
self
.
template_embedder
=
TemplateEmbedder
(
**
self
.
extra_msa_config
[
"extra_msa_stack"
],
template_config
,
)
)
self
.
extra_msa_embedder
=
ExtraMSAEmbedder
(
**
extra_msa_config
[
"extra_msa_embedder"
],
)
self
.
extra_msa_stack
=
ExtraMSAStack
(
**
extra_msa_config
[
"extra_msa_stack"
],
)
self
.
evoformer
=
EvoformerStack
(
self
.
evoformer
=
EvoformerStack
(
**
config
[
"evoformer_stack"
],
**
self
.
config
[
"evoformer_stack"
],
)
)
self
.
structure_module
=
StructureModule
(
self
.
structure_module
=
StructureModule
(
is_multimer
=
self
.
globals
.
is_multimer
,
is_multimer
=
self
.
globals
.
is_multimer
,
**
config
[
"structure_module"
],
**
self
.
config
[
"structure_module"
],
)
)
self
.
aux_heads
=
AuxiliaryHeads
(
self
.
aux_heads
=
AuxiliaryHeads
(
config
[
"heads"
],
self
.
config
[
"heads"
],
)
)
self
.
config
=
config
def
embed_templates
(
self
,
batch
,
feats
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
):
if
(
self
.
globals
.
is_multimer
):
asym_id
=
feats
[
"asym_id"
]
multichain_mask_2d
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
)
template_embeds
=
self
.
template_embedder
(
batch
,
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
templ_dim
,
chunk_size
=
self
.
globals
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
)
feats
[
"template_torsion_angles_mask"
]
=
(
template_embeds
[
"template_mask"
]
)
else
:
if
(
self
.
template_config
.
offload_templates
):
return
embed_templates_offload
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
elif
(
self
.
template_config
.
average_templates
):
return
embed_templates_average
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
template_embeds
=
self
.
template_embedder
(
batch
,
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
templ_dim
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
)
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
,
_recycle
=
True
):
return
template_embeds
def
iteration
(
self
,
feats
,
prevs
,
_recycle
=
True
):
# Primary output dictionary
# Primary output dictionary
outputs
=
{}
outputs
=
{}
...
@@ -125,19 +181,38 @@ class AlphaFold(nn.Module):
...
@@ -125,19 +181,38 @@ class AlphaFold(nn.Module):
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
device
=
feats
[
"target_feat"
].
device
device
=
feats
[
"target_feat"
].
device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Prep some features
# Prep some features
seq_mask
=
feats
[
"seq_mask"
]
seq_mask
=
feats
[
"seq_mask"
]
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
msa_mask
=
feats
[
"msa_mask"
]
msa_mask
=
feats
[
"msa_mask"
]
## Initialize the MSA and pair representations
# Initialize the MSA and pair representations
if
(
self
.
globals
.
is_multimer
):
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m
,
z
=
self
.
input_embedder
(
feats
)
# m: [*, S_c, N, C_m]
else
:
# z: [*, N, N, C_z]
# m: [*, S_c, N, C_m]
m
,
z
=
self
.
input_embedder
(
feats
)
# z: [*, N, N, C_z]
m
,
z
=
self
.
input_embedder
(
feats
[
"target_feat"
],
feats
[
"residue_index"
],
feats
[
"msa_feat"
],
inplace_safe
=
inplace_safe
,
)
# Initialize the recycling embeddings, if needs be
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory
m_1_prev
,
z_prev
,
x_prev
=
reversed
([
prevs
.
pop
()
for
_
in
range
(
3
)])
# Initialize the recycling embeddings, if needs be
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
# [*, N, C_m]
# [*, N, C_m]
m_1_prev
=
m
.
new_zeros
(
m_1_prev
=
m
.
new_zeros
(
...
@@ -161,69 +236,58 @@ class AlphaFold(nn.Module):
...
@@ -161,69 +236,58 @@ class AlphaFold(nn.Module):
feats
[
"aatype"
],
x_prev
,
None
feats
[
"aatype"
],
x_prev
,
None
).
to
(
dtype
=
z
.
dtype
)
).
to
(
dtype
=
z
.
dtype
)
# The recycling embedder is memory-intensive, so we offload first
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
m
=
m
.
cpu
()
z
=
z
.
cpu
()
# m_1_prev_emb: [*, N, C_m]
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb
,
z_prev_emb
=
self
.
recycling_embedder
(
m_1_prev_emb
,
z_prev_emb
=
self
.
recycling_embedder
(
m_1_prev
,
m_1_prev
,
z_prev
,
z_prev
,
x_prev
,
x_prev
,
inplace_safe
=
inplace_safe
,
)
)
# If the number of recycling iterations is 0, skip recycling
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
# altogether. We zero them this way instead of computing them
m
=
m
.
to
(
m_1_prev_emb
.
device
)
# conditionally to avoid leaving parameters unused, which has annoying
z
=
z
.
to
(
z_prev
.
device
)
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
# [*, S_c, N, C_m]
# [*, S_c, N, C_m]
m
[...,
0
,
:,
:]
+=
m_1_prev_emb
m
[...,
0
,
:,
:]
+=
m_1_prev_emb
# [*, N, N, C_z]
# [*, N, N, C_z]
z
+
=
z_prev_emb
z
=
add
(
z
,
z_prev_emb
,
inplace
=
inplace_safe
)
# Possibly prevents memory fragmentation
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
# Embed the templates + merge with MSA/pair embeddings
if
self
.
config
.
template
.
enabled
:
if
self
.
config
.
template
.
enabled
:
template_feats
=
{
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
}
if
(
self
.
globals
.
is_multimer
):
template_embeds
=
self
.
embed_templates
(
asym_id
=
feats
[
"asym_id"
]
template_feats
,
multichain_mask_2d
=
(
feats
,
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
z
,
)
pair_mask
.
to
(
dtype
=
z
.
dtype
),
template_embeds
=
self
.
template_embedder
(
no_batch_dims
,
template_feats
,
inplace_safe
=
inplace_safe
,
z
,
)
pair_mask
.
to
(
dtype
=
z
.
dtype
),
no_batch_dims
,
chunk_size
=
self
.
globals
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
)
feats
[
"template_torsion_angles_mask"
]
=
(
template_embeds
[
"template_mask"
]
)
else
:
template_embeds
=
self
.
template_embedder
(
template_feats
,
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
no_batch_dims
,
self
.
globals
.
chunk_size
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
z
=
add
(
z
,
template_embeds
.
pop
(
"template_pair_embedding"
),
inplace_safe
,
)
if
(
if
(
self
.
config
.
template
.
embed_angles
or
"template_single_embedding"
in
template_embeds
(
self
.
globals
.
is_multimer
and
self
.
config
.
template
.
enabled
)
):
):
# [*, S = S_c + S_t, N, C_m]
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
m
=
torch
.
cat
(
...
@@ -253,41 +317,80 @@ class AlphaFold(nn.Module):
...
@@ -253,41 +317,80 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e]
# [*, S_e, N, C_e]
extra_msa_feat
=
extra_msa_fn
(
feats
)
extra_msa_feat
=
extra_msa_fn
(
feats
)
extra_msa_feat
=
self
.
extra_msa_embedder
(
extra_msa_feat
)
a
=
self
.
extra_msa_embedder
(
extra_msa_feat
)
if
(
self
.
globals
.
offload_inference
):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors
=
[
a
,
z
]
del
a
,
z
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
.
_forward_offload
(
input_tensors
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
input_tensors
else
:
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
(
a
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
# [*, N, N, C_z]
# Run MSA + pair embeddings through the trunk of the network
z
=
self
.
extra_msa_stack
(
# m: [*, S, N, C_m]
extra_msa_feat
,
# z: [*, N, N, C_z]
z
,
# s: [*, N, C_s]
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
extra_msa_feat
.
dtype
),
if
(
self
.
globals
.
offload_inference
):
input_tensors
=
[
m
,
z
]
del
m
,
z
m
,
z
,
s
=
self
.
evoformer
.
_forward_offload
(
input_tensors
,
msa_mask
=
msa_mask
.
to
(
dtype
=
input_tensors
[
0
].
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
input_tensors
[
1
].
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
input_tensors
else
:
m
,
z
,
s
=
self
.
evoformer
(
m
,
z
,
msa_mask
=
msa_mask
.
to
(
dtype
=
m
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
use_flash
=
self
.
globals
.
use_flash
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m
,
z
,
s
=
self
.
evoformer
(
m
,
z
,
msa_mask
=
msa_mask
.
to
(
dtype
=
m
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
outputs
[
"msa"
]
=
m
[...,
:
n_seq
,
:,
:]
outputs
[
"msa"
]
=
m
[...,
:
n_seq
,
:,
:]
outputs
[
"pair"
]
=
z
outputs
[
"pair"
]
=
z
outputs
[
"single"
]
=
s
outputs
[
"single"
]
=
s
del
z
# Predict 3D structure
# Predict 3D structure
outputs
[
"sm"
]
=
self
.
structure_module
(
outputs
[
"sm"
]
=
self
.
structure_module
(
s
,
outputs
,
z
,
feats
[
"aatype"
],
feats
[
"aatype"
],
mask
=
feats
[
"seq_mask"
].
to
(
dtype
=
s
.
dtype
),
mask
=
feats
[
"seq_mask"
].
to
(
dtype
=
s
.
dtype
),
inplace_safe
=
inplace_safe
,
_offload_inference
=
self
.
globals
.
offload_inference
,
)
)
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
outputs
[
"sm"
][
"positions"
][
-
1
],
feats
outputs
[
"sm"
][
"positions"
][
-
1
],
feats
...
@@ -301,7 +404,7 @@ class AlphaFold(nn.Module):
...
@@ -301,7 +404,7 @@ class AlphaFold(nn.Module):
m_1_prev
=
m
[...,
0
,
:,
:]
m_1_prev
=
m
[...,
0
,
:,
:]
# [*, N, N, C_z]
# [*, N, N, C_z]
z_prev
=
z
z_prev
=
outputs
[
"pair"
]
# [*, N, 3]
# [*, N, 3]
x_prev
=
outputs
[
"final_atom_positions"
]
x_prev
=
outputs
[
"final_atom_positions"
]
...
@@ -379,14 +482,13 @@ class AlphaFold(nn.Module):
...
@@ -379,14 +482,13 @@ class AlphaFold(nn.Module):
"""
"""
# Initialize recycling embeddings
# Initialize recycling embeddings
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled
=
torch
.
is_grad_enabled
()
is_grad_enabled
=
torch
.
is_grad_enabled
()
self
.
_disable_activation_checkpointing
()
# Main recycling loop
# Main recycling loop
num_iters
=
batch
[
"aatype"
].
shape
[
-
1
]
num_iters
=
batch
[
"aatype"
].
shape
[
-
1
]
for
cycle_no
in
range
(
num_iters
):
for
cycle_no
in
range
(
num_iters
):
# Select the features for the current recycling cycle
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
...
@@ -395,7 +497,6 @@ class AlphaFold(nn.Module):
...
@@ -395,7 +497,6 @@ class AlphaFold(nn.Module):
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
if
is_final_iter
:
if
is_final_iter
:
self
.
_enable_activation_checkpointing
()
# Sidestep AMP bug (PyTorch issue #65766)
# Sidestep AMP bug (PyTorch issue #65766)
if
torch
.
is_autocast_enabled
():
if
torch
.
is_autocast_enabled
():
torch
.
clear_autocast_cache
()
torch
.
clear_autocast_cache
()
...
@@ -403,12 +504,15 @@ class AlphaFold(nn.Module):
...
@@ -403,12 +504,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model
# Run the next iteration of the model
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
feats
,
feats
,
m_1_prev
,
prevs
,
z_prev
,
x_prev
,
_recycle
=
(
num_iters
>
1
)
_recycle
=
(
num_iters
>
1
)
)
)
if
(
not
is_final_iter
):
del
outputs
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
del
m_1_prev
,
z_prev
,
x_prev
# Run auxiliary heads
# Run auxiliary heads
outputs
.
update
(
self
.
aux_heads
(
outputs
))
outputs
.
update
(
self
.
aux_heads
(
outputs
))
...
...
openfold/model/msa.py
View file @
39a6d0e6
...
@@ -26,8 +26,8 @@ from openfold.model.primitives import (
...
@@ -26,8 +26,8 @@ from openfold.model.primitives import (
_attention_chunked_trainable
,
_attention_chunked_trainable
,
)
)
from
openfold.utils.checkpointing
import
get_checkpoint_fn
from
openfold.utils.checkpointing
import
get_checkpoint_fn
from
openfold.utils.chunk_utils
import
chunk_layer
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
chunk_layer
,
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
)
)
...
@@ -89,21 +89,38 @@ class MSAAttention(nn.Module):
...
@@ -89,21 +89,38 @@ class MSAAttention(nn.Module):
@
torch
.
jit
.
ignore
@
torch
.
jit
.
ignore
def
_chunk
(
self
,
def
_chunk
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
biases
:
List
[
torch
.
Tensor
],
biases
:
Optional
[
List
[
torch
.
Tensor
]],
use_memory_efficient_kernel
:
bool
,
chunk_size
:
int
,
chunk_size
:
int
,
use_memory_efficient_kernel
:
bool
,
use_lma
:
bool
,
use_flash
:
bool
,
flash_mask
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
mha
=
partial
(
def
fn
(
m
,
biases
,
flash_mask
):
self
.
mha
,
m
=
self
.
layer_norm_m
(
m
)
use_memory_efficient_kernel
=
use_memory_efficient_kernel
return
self
.
mha
(
)
q_x
=
m
,
kv_x
=
m
,
biases
=
biases
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
flash_mask
=
flash_mask
,
)
inputs
=
{
"m"
:
m
}
if
(
biases
is
not
None
):
inputs
[
"biases"
]
=
biases
else
:
fn
=
partial
(
fn
,
biases
=
None
)
if
(
use_flash
and
flash_mask
is
not
None
):
inputs
[
"flash_mask"
]
=
flash_mask
else
:
fn
=
partial
(
fn
,
flash_mask
=
None
)
return
chunk_layer
(
return
chunk_layer
(
mha
,
fn
,
{
inputs
,
"q_x"
:
m
,
"kv_x"
:
m
,
"biases"
:
biases
,
},
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
m
.
shape
[:
-
2
])
no_batch_dims
=
len
(
m
.
shape
[:
-
2
])
)
)
...
@@ -111,11 +128,9 @@ class MSAAttention(nn.Module):
...
@@ -111,11 +128,9 @@ class MSAAttention(nn.Module):
def
_prep_inputs
(
self
,
def
_prep_inputs
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
mask
:
Optional
[
torch
.
Tensor
]
mask
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
inplace_safe
:
bool
=
False
,
# [*, N_seq, N_res, C_m]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
m
=
self
.
layer_norm_m
(
m
)
n_seq
,
n_res
=
m
.
shape
[
-
3
:
-
1
]
n_seq
,
n_res
=
m
.
shape
[
-
3
:
-
1
]
if
mask
is
None
:
if
mask
is
None
:
# [*, N_seq, N_res]
# [*, N_seq, N_res]
...
@@ -131,11 +146,20 @@ class MSAAttention(nn.Module):
...
@@ -131,11 +146,20 @@ class MSAAttention(nn.Module):
self
.
layer_norm_z
is
not
None
and
# benefit of
self
.
layer_norm_z
is
not
None
and
# benefit of
self
.
linear_z
is
not
None
# TorchScript
self
.
linear_z
is
not
None
# TorchScript
):
):
# [*, N_res, N_res, C_z]
chunks
=
[]
z
=
self
.
layer_norm_z
(
z
)
for
i
in
range
(
0
,
z
.
shape
[
-
3
],
256
):
z_chunk
=
z
[...,
i
:
i
+
256
,
:,
:]
# [*, N_res, N_res, C_z]
z_chunk
=
self
.
layer_norm_z
(
z_chunk
)
# [*, N_res, N_res, no_heads]
z_chunk
=
self
.
linear_z
(
z_chunk
)
chunks
.
append
(
z_chunk
)
# [*, N_res, N_res, no_heads]
z
=
torch
.
cat
(
chunks
,
dim
=-
3
)
z
=
self
.
linear_z
(
z
)
# [*, 1, no_heads, N_res, N_res]
# [*, 1, no_heads, N_res, N_res]
z
=
permute_final_dims
(
z
,
(
2
,
0
,
1
)).
unsqueeze
(
-
4
)
z
=
permute_final_dims
(
z
,
(
2
,
0
,
1
)).
unsqueeze
(
-
4
)
...
@@ -149,6 +173,7 @@ class MSAAttention(nn.Module):
...
@@ -149,6 +173,7 @@ class MSAAttention(nn.Module):
mask
:
Optional
[
torch
.
Tensor
],
mask
:
Optional
[
torch
.
Tensor
],
chunk_logits
:
int
,
chunk_logits
:
int
,
checkpoint
:
bool
,
checkpoint
:
bool
,
inplace_safe
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
MSA attention with training-time chunking of the softmax computation.
MSA attention with training-time chunking of the softmax computation.
...
@@ -158,7 +183,10 @@ class MSAAttention(nn.Module):
...
@@ -158,7 +183,10 @@ class MSAAttention(nn.Module):
MSA_DIM
=
-
4
MSA_DIM
=
-
4
def
_get_qkv
(
m
,
z
):
def
_get_qkv
(
m
,
z
):
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
,
inplace_safe
=
inplace_safe
)
m
=
self
.
layer_norm_m
(
m
)
q
,
k
,
v
=
self
.
mha
.
_prep_qkv
(
m
,
m
)
q
,
k
,
v
=
self
.
mha
.
_prep_qkv
(
m
,
m
)
return
m
,
q
,
k
,
v
,
mask_bias
,
z
return
m
,
q
,
k
,
v
,
mask_bias
,
z
...
@@ -193,6 +221,9 @@ class MSAAttention(nn.Module):
...
@@ -193,6 +221,9 @@ class MSAAttention(nn.Module):
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_flash
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_chunk_logits
:
Optional
[
int
]
=
None
,
_chunk_logits
:
Optional
[
int
]
=
None
,
_checkpoint_chunks
:
Optional
[
bool
]
=
None
,
_checkpoint_chunks
:
Optional
[
bool
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -214,23 +245,43 @@ class MSAAttention(nn.Module):
...
@@ -214,23 +245,43 @@ class MSAAttention(nn.Module):
if
(
_chunk_logits
is
not
None
):
if
(
_chunk_logits
is
not
None
):
return
self
.
_chunked_msa_attn
(
return
self
.
_chunked_msa_attn
(
m
=
m
,
z
=
z
,
mask
=
mask
,
m
=
m
,
z
=
z
,
mask
=
mask
,
chunk_logits
=
_chunk_logits
,
checkpoint
=
_checkpoint_chunks
chunk_logits
=
_chunk_logits
,
)
checkpoint
=
_checkpoint_chunks
,
inplace_safe
=
inplace_safe
,
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
)
)
biases
=
[
mask_bias
]
if
(
use_flash
):
if
(
z
is
not
None
):
assert
z
is
None
biases
.
append
(
z
)
biases
=
None
else
:
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
,
inplace_safe
=
inplace_safe
)
biases
=
[
mask_bias
]
if
(
z
is
not
None
):
biases
.
append
(
z
)
if
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
m
=
self
.
_chunk
(
m
,
biases
,
use_memory_efficient_kernel
,
chunk_size
)
m
=
self
.
_chunk
(
m
,
biases
,
chunk_size
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
flash_mask
=
mask
,
)
else
:
else
:
m
=
self
.
layer_norm_m
(
m
)
m
=
self
.
mha
(
m
=
self
.
mha
(
q_x
=
m
,
q_x
=
m
,
kv_x
=
m
,
kv_x
=
m
,
biases
=
biases
,
biases
=
biases
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
flash_mask
=
mask
,
)
)
return
m
return
m
...
@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module):
...
@@ -305,7 +356,8 @@ class MSAColumnAttention(nn.Module):
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_flash
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module):
...
@@ -323,7 +375,13 @@ class MSAColumnAttention(nn.Module):
if
mask
is
not
None
:
if
mask
is
not
None
:
mask
=
mask
.
transpose
(
-
1
,
-
2
)
mask
=
mask
.
transpose
(
-
1
,
-
2
)
m
=
self
.
_msa_att
(
m
,
mask
=
mask
,
chunk_size
=
chunk_size
)
m
=
self
.
_msa_att
(
m
,
mask
=
mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_flash
=
use_flash
,
)
# [*, N_seq, N_res, C_in]
# [*, N_seq, N_res, C_in]
m
=
m
.
transpose
(
-
2
,
-
3
)
m
=
m
.
transpose
(
-
2
,
-
3
)
...
@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module):
...
@@ -360,13 +418,19 @@ class MSAColumnGlobalAttention(nn.Module):
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
mha_input
=
{
mha_input
=
{
"m"
:
m
,
"m"
:
m
,
"mask"
:
mask
,
"mask"
:
mask
,
}
}
def
fn
(
m
,
mask
):
m
=
self
.
layer_norm_m
(
m
)
return
self
.
global_attention
(
m
,
mask
,
use_lma
=
use_lma
)
return
chunk_layer
(
return
chunk_layer
(
self
.
global_attentio
n
,
f
n
,
mha_input
,
mha_input
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
m
.
shape
[:
-
2
]),
no_batch_dims
=
len
(
m
.
shape
[:
-
2
]),
...
@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module):
...
@@ -377,6 +441,7 @@ class MSAColumnGlobalAttention(nn.Module):
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
n_seq
,
n_res
,
c_in
=
m
.
shape
[
-
3
:]
n_seq
,
n_res
,
c_in
=
m
.
shape
[
-
3
:]
...
@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module):
...
@@ -393,12 +458,13 @@ class MSAColumnGlobalAttention(nn.Module):
mask
=
mask
.
transpose
(
-
1
,
-
2
)
mask
=
mask
.
transpose
(
-
1
,
-
2
)
# [*, N_res, N_seq, C_in]
# [*, N_res, N_seq, C_in]
m
=
self
.
layer_norm_m
(
m
)
#
m = self.layer_norm_m(m)
if
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
m
=
self
.
_chunk
(
m
,
mask
,
chunk_size
)
m
=
self
.
_chunk
(
m
,
mask
,
chunk_size
,
use_lma
=
use_lma
)
else
:
else
:
m
=
self
.
global_attention
(
m
=
m
,
mask
=
mask
)
m
=
self
.
layer_norm_m
(
m
)
m
=
self
.
global_attention
(
m
=
m
,
mask
=
mask
,
use_lma
=
use_lma
)
# [*, N_seq, N_res, C_in]
# [*, N_seq, N_res, C_in]
m
=
m
.
transpose
(
-
2
,
-
3
)
m
=
m
.
transpose
(
-
2
,
-
3
)
...
...
openfold/model/outer_product_mean.py
View file @
39a6d0e6
...
@@ -20,7 +20,8 @@ import torch
...
@@ -20,7 +20,8 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
from
openfold.model.primitives
import
Linear
from
openfold.utils.tensor_utils
import
chunk_layer
from
openfold.utils.chunk_utils
import
chunk_layer
from
openfold.utils.precision_utils
import
is_fp16_enabled
class
OuterProductMean
(
nn
.
Module
):
class
OuterProductMean
(
nn
.
Module
):
...
@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module):
...
@@ -82,15 +83,22 @@ class OuterProductMean(nn.Module):
no_batch_dims
=
1
,
no_batch_dims
=
1
,
)
)
out
.
append
(
outer
)
out
.
append
(
outer
)
outer
=
torch
.
stack
(
out
,
dim
=
0
)
# For some cursed reason making this distinction saves memory
if
(
len
(
out
)
==
1
):
outer
=
out
[
0
].
unsqueeze
(
0
)
else
:
outer
=
torch
.
stack
(
out
,
dim
=
0
)
outer
=
outer
.
reshape
(
a
.
shape
[:
-
3
]
+
outer
.
shape
[
1
:])
outer
=
outer
.
reshape
(
a
.
shape
[:
-
3
]
+
outer
.
shape
[
1
:])
return
outer
return
outer
def
forward
(
self
,
def
_
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
chunk_size
:
Optional
[
int
]
=
None
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module):
...
@@ -105,12 +113,17 @@ class OuterProductMean(nn.Module):
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
# [*, N_seq, N_res, C_m]
# [*, N_seq, N_res, C_m]
m
=
self
.
layer_norm
(
m
)
ln
=
self
.
layer_norm
(
m
)
# [*, N_seq, N_res, C]
# [*, N_seq, N_res, C]
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
a
=
self
.
linear_1
(
m
)
*
mask
a
=
self
.
linear_1
(
ln
)
b
=
self
.
linear_2
(
m
)
*
mask
a
=
a
*
mask
b
=
self
.
linear_2
(
ln
)
b
=
b
*
mask
del
ln
a
=
a
.
transpose
(
-
2
,
-
3
)
a
=
a
.
transpose
(
-
2
,
-
3
)
b
=
b
.
transpose
(
-
2
,
-
3
)
b
=
b
.
transpose
(
-
2
,
-
3
)
...
@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module):
...
@@ -122,8 +135,25 @@ class OuterProductMean(nn.Module):
# [*, N_res, N_res, 1]
# [*, N_res, N_res, 1]
norm
=
torch
.
einsum
(
"...abc,...adc->...bdc"
,
mask
,
mask
)
norm
=
torch
.
einsum
(
"...abc,...adc->...bdc"
,
mask
,
mask
)
norm
=
norm
+
self
.
eps
# [*, N_res, N_res, C_z]
# [*, N_res, N_res, C_z]
outer
=
outer
/
(
self
.
eps
+
norm
)
if
(
inplace_safe
):
outer
/=
norm
else
:
outer
=
outer
/
norm
return
outer
return
outer
def
forward
(
self
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
if
(
is_fp16_enabled
()):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
return
self
.
_forward
(
m
.
float
(),
mask
,
chunk_size
,
inplace_safe
)
else
:
return
self
.
_forward
(
m
,
mask
,
chunk_size
,
inplace_safe
)
openfold/model/pair_transition.py
View file @
39a6d0e6
...
@@ -18,7 +18,7 @@ import torch
...
@@ -18,7 +18,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.utils.
tensor
_utils
import
chunk_layer
from
openfold.utils.
chunk
_utils
import
chunk_layer
class
PairTransition
(
nn
.
Module
):
class
PairTransition
(
nn
.
Module
):
...
@@ -46,12 +46,16 @@ class PairTransition(nn.Module):
...
@@ -46,12 +46,16 @@ class PairTransition(nn.Module):
self
.
linear_2
=
Linear
(
self
.
n
*
self
.
c_z
,
c_z
,
init
=
"final"
)
self
.
linear_2
=
Linear
(
self
.
n
*
self
.
c_z
,
c_z
,
init
=
"final"
)
def
_transition
(
self
,
z
,
mask
):
def
_transition
(
self
,
z
,
mask
):
# [*, N_res, N_res, C_z]
z
=
self
.
layer_norm
(
z
)
# [*, N_res, N_res, C_hidden]
# [*, N_res, N_res, C_hidden]
z
=
self
.
linear_1
(
z
)
z
=
self
.
linear_1
(
z
)
z
=
self
.
relu
(
z
)
z
=
self
.
relu
(
z
)
# [*, N_res, N_res, C_z]
# [*, N_res, N_res, C_z]
z
=
self
.
linear_2
(
z
)
*
mask
z
=
self
.
linear_2
(
z
)
z
=
z
*
mask
return
z
return
z
...
@@ -68,7 +72,6 @@ class PairTransition(nn.Module):
...
@@ -68,7 +72,6 @@ class PairTransition(nn.Module):
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
)
)
def
forward
(
self
,
def
forward
(
self
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
...
@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
# [*, N_res, N_res, 1]
# [*, N_res, N_res, 1]
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
# [*, N_res, N_res, C_z]
z
=
self
.
layer_norm
(
z
)
if
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
z
=
self
.
_chunk
(
z
,
mask
,
chunk_size
)
z
=
self
.
_chunk
(
z
,
mask
,
chunk_size
)
else
:
else
:
...
...
openfold/model/primitives.py
View file @
39a6d0e6
...
@@ -13,24 +13,39 @@
...
@@ -13,24 +13,39 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
partial
from
functools
import
partial
import
importlib
import
math
import
math
from
typing
import
Optional
,
Callable
,
List
,
Tuple
,
Sequence
from
typing
import
Optional
,
Callable
,
List
,
Tuple
,
Sequence
import
numpy
as
np
import
numpy
as
np
import
deepspeed
deepspeed_is_installed
=
importlib
.
util
.
find_spec
(
"deepspeed"
)
is
not
None
if
(
deepspeed_is_installed
):
import
deepspeed
fa_is_installed
=
importlib
.
util
.
find_spec
(
"flash_attn"
)
is
not
None
if
(
fa_is_installed
):
from
flash_attn.bert_padding
import
unpad_input
,
pad_input
from
flash_attn.flash_attention
import
FlashAttention
from
flash_attn.flash_attn_interface
import
flash_attn_unpadded_kvpacked_func
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
scipy.stats
import
truncnorm
from
scipy.stats
import
truncnorm
from
openfold.utils.checkpointing
import
get_checkpoint_fn
from
openfold.utils.checkpointing
import
get_checkpoint_fn
from
openfold.utils.chunk_utils
import
_chunk_slice
from
openfold.utils.kernel.attention_core
import
attention_core
from
openfold.utils.kernel.attention_core
import
attention_core
from
openfold.utils.precision_utils
import
is_fp16_enabled
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
_chunk_slice
,
)
)
DEFAULT_LMA_Q_CHUNK_SIZE
=
1024
DEFAULT_LMA_KV_CHUNK_SIZE
=
4096
def
_prod
(
nums
):
def
_prod
(
nums
):
out
=
1
out
=
1
for
n
in
nums
:
for
n
in
nums
:
...
@@ -145,26 +160,26 @@ class Linear(nn.Linear):
...
@@ -145,26 +160,26 @@ class Linear(nn.Linear):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
self
.
bias
.
fill_
(
0
)
self
.
bias
.
fill_
(
0
)
if
init_fn
is
not
None
:
with
torch
.
no_grad
():
init_fn
(
self
.
weight
,
self
.
bias
)
if
init_fn
is
not
None
:
else
:
init_fn
(
self
.
weight
,
self
.
bias
)
if
init
==
"default"
:
lecun_normal_init_
(
self
.
weight
)
elif
init
==
"relu"
:
he_normal_init_
(
self
.
weight
)
elif
init
==
"glorot"
:
glorot_uniform_init_
(
self
.
weight
)
elif
init
==
"gating"
:
gating_init_
(
self
.
weight
)
if
bias
:
with
torch
.
no_grad
():
self
.
bias
.
fill_
(
1.0
)
elif
init
==
"normal"
:
normal_init_
(
self
.
weight
)
elif
init
==
"final"
:
final_init_
(
self
.
weight
)
else
:
else
:
raise
ValueError
(
"Invalid init string."
)
if
init
==
"default"
:
lecun_normal_init_
(
self
.
weight
)
elif
init
==
"relu"
:
he_normal_init_
(
self
.
weight
)
elif
init
==
"glorot"
:
glorot_uniform_init_
(
self
.
weight
)
elif
init
==
"gating"
:
gating_init_
(
self
.
weight
)
if
bias
:
self
.
bias
.
fill_
(
1.0
)
elif
init
==
"normal"
:
normal_init_
(
self
.
weight
)
elif
init
==
"final"
:
final_init_
(
self
.
weight
)
else
:
raise
ValueError
(
"Invalid init string."
)
class
LayerNorm
(
nn
.
Module
):
class
LayerNorm
(
nn
.
Module
):
...
@@ -179,7 +194,11 @@ class LayerNorm(nn.Module):
...
@@ -179,7 +194,11 @@ class LayerNorm(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
d
=
x
.
dtype
d
=
x
.
dtype
if
(
d
is
torch
.
bfloat16
and
not
deepspeed
.
utils
.
is_initialized
()):
deepspeed_is_initialized
=
(
deepspeed_is_installed
and
deepspeed
.
utils
.
is_initialized
()
)
if
(
d
is
torch
.
bfloat16
and
not
deepspeed_is_initialized
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
out
=
nn
.
functional
.
layer_norm
(
out
=
nn
.
functional
.
layer_norm
(
x
,
x
,
...
@@ -207,7 +226,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
...
@@ -207,7 +226,11 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
type bfloat16
type bfloat16
"""
"""
d
=
t
.
dtype
d
=
t
.
dtype
if
(
d
is
torch
.
bfloat16
and
not
deepspeed
.
utils
.
is_initialized
()):
deepspeed_is_initialized
=
(
deepspeed_is_installed
and
deepspeed
.
utils
.
is_initialized
()
)
if
(
d
is
torch
.
bfloat16
and
not
deepspeed_is_initialized
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
s
=
torch
.
nn
.
functional
.
softmax
(
t
,
dim
=
dim
)
s
=
torch
.
nn
.
functional
.
softmax
(
t
,
dim
=
dim
)
else
:
else
:
...
@@ -403,8 +426,10 @@ class Attention(nn.Module):
...
@@ -403,8 +426,10 @@ class Attention(nn.Module):
biases
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
biases
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
q_chunk_size
:
Optional
[
int
]
=
None
,
lma_q_chunk_size
:
int
=
DEFAULT_LMA_Q_CHUNK_SIZE
,
kv_chunk_size
:
Optional
[
int
]
=
None
,
lma_kv_chunk_size
:
int
=
DEFAULT_LMA_KV_CHUNK_SIZE
,
use_flash
:
bool
=
False
,
flash_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -423,29 +448,41 @@ class Attention(nn.Module):
...
@@ -423,29 +448,41 @@ class Attention(nn.Module):
Whether to use low-memory attention (Staats & Rabe 2021). If
Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch
none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
implementation is used instead
q_chunk_size:
lma_
q_chunk_size:
Query chunk size (for LMA)
Query chunk size (for LMA)
kv_chunk_size:
lma_
kv_chunk_size:
Key/Value chunk size (for LMA)
Key/Value chunk size (for LMA)
Returns
Returns
[*, Q, C_q] attention update
[*, Q, C_q] attention update
"""
"""
if
(
biases
is
None
):
if
(
use_lma
and
(
lma_q_chunk_size
is
None
or
lma_kv_chunk_size
is
None
)):
biases
=
[]
if
(
use_lma
and
(
q_chunk_size
is
None
or
kv_chunk_size
is
None
)):
raise
ValueError
(
raise
ValueError
(
"If use_lma is specified, q_chunk_size and
kv_chunk_size must
"
"If use_lma is specified,
lma_
q_chunk_size and "
"be provided"
"
lma_kv_chunk_size must
be provided"
)
)
if
(
use_memory_efficient_kernel
and
use_lma
):
if
(
use_flash
and
biases
is
not
None
):
raise
ValueError
(
raise
ValueError
(
"Choose one of use_memory_efficient_kernel and use_lma"
"use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
)
)
attn_options
=
[
use_memory_efficient_kernel
,
use_lma
,
use_flash
]
if
(
sum
(
attn_options
)
>
1
):
raise
ValueError
(
"Choose at most one alternative attention algorithm"
)
if
(
biases
is
None
):
biases
=
[]
# [*, H, Q/K, C_hidden]
# [*, H, Q/K, C_hidden]
q
,
k
,
v
=
self
.
_prep_qkv
(
q_x
,
kv_x
)
q
,
k
,
v
=
self
.
_prep_qkv
(
q_x
,
kv_x
)
# [*, Q, H, C_hidden]
# [*, Q, H, C_hidden]
if
is_fp16_enabled
():
use_memory_efficient_kernel
=
False
if
(
use_memory_efficient_kernel
):
if
(
use_memory_efficient_kernel
):
if
(
len
(
biases
)
>
2
):
if
(
len
(
biases
)
>
2
):
raise
ValueError
(
raise
ValueError
(
...
@@ -459,7 +496,10 @@ class Attention(nn.Module):
...
@@ -459,7 +496,10 @@ class Attention(nn.Module):
b
.
expand
(
b
.
shape
[:
-
2
]
+
(
q_x
.
shape
[
-
2
],)
+
(
kv_x
.
shape
[
-
2
],))
b
.
expand
(
b
.
shape
[:
-
2
]
+
(
q_x
.
shape
[
-
2
],)
+
(
kv_x
.
shape
[
-
2
],))
for
b
in
biases
for
b
in
biases
]
]
o
=
_lma
(
q
,
k
,
v
,
biases
,
q_chunk_size
,
kv_chunk_size
)
o
=
_lma
(
q
,
k
,
v
,
biases
,
lma_q_chunk_size
,
lma_kv_chunk_size
)
o
=
o
.
transpose
(
-
2
,
-
3
)
elif
(
use_flash
):
o
=
_flash_attn
(
q
,
k
,
v
,
flash_mask
)
else
:
else
:
o
=
_attention
(
q
,
k
,
v
,
biases
)
o
=
_attention
(
q
,
k
,
v
,
biases
)
o
=
o
.
transpose
(
-
2
,
-
3
)
o
=
o
.
transpose
(
-
2
,
-
3
)
...
@@ -494,7 +534,11 @@ class GlobalAttention(nn.Module):
...
@@ -494,7 +534,11 @@ class GlobalAttention(nn.Module):
self
.
sigmoid
=
nn
.
Sigmoid
()
self
.
sigmoid
=
nn
.
Sigmoid
()
def
forward
(
self
,
m
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
m
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
use_lma
:
bool
=
False
,
)
->
torch
.
Tensor
:
# [*, N_res, C_in]
# [*, N_res, C_in]
q
=
torch
.
sum
(
m
*
mask
.
unsqueeze
(
-
1
),
dim
=-
2
)
/
(
q
=
torch
.
sum
(
m
*
mask
.
unsqueeze
(
-
1
),
dim
=-
2
)
/
(
torch
.
sum
(
mask
,
dim
=-
1
)[...,
None
]
+
self
.
eps
torch
.
sum
(
mask
,
dim
=-
1
)[...,
None
]
+
self
.
eps
...
@@ -511,20 +555,30 @@ class GlobalAttention(nn.Module):
...
@@ -511,20 +555,30 @@ class GlobalAttention(nn.Module):
k
=
self
.
linear_k
(
m
)
k
=
self
.
linear_k
(
m
)
v
=
self
.
linear_v
(
m
)
v
=
self
.
linear_v
(
m
)
# [*, N_res, H, N_seq]
a
=
torch
.
matmul
(
q
,
k
.
transpose
(
-
1
,
-
2
),
# [*, N_res, C_hidden, N_seq]
)
bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
:]
bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
:]
a
+=
bias
if
(
not
use_lma
):
a
=
softmax_no_cast
(
a
)
# [*, N_res, H, N_seq]
a
=
torch
.
matmul
(
q
,
k
.
transpose
(
-
1
,
-
2
),
# [*, N_res, C_hidden, N_seq]
)
a
+=
bias
a
=
softmax_no_cast
(
a
)
# [*, N_res, H, C_hidden]
# [*, N_res, H, C_hidden]
o
=
torch
.
matmul
(
o
=
torch
.
matmul
(
a
,
a
,
v
,
v
,
)
)
else
:
o
=
_lma
(
q
,
k
,
v
,
[
bias
],
DEFAULT_LMA_Q_CHUNK_SIZE
,
DEFAULT_LMA_KV_CHUNK_SIZE
)
# [*, N_res, N_seq, C_hidden]
# [*, N_res, N_seq, C_hidden]
g
=
self
.
sigmoid
(
self
.
linear_g
(
m
))
g
=
self
.
sigmoid
(
self
.
linear_g
(
m
))
...
@@ -552,12 +606,12 @@ def _lma(
...
@@ -552,12 +606,12 @@ def _lma(
q_chunk_size
:
int
,
q_chunk_size
:
int
,
kv_chunk_size
:
int
,
kv_chunk_size
:
int
,
):
):
no_q
,
no_kv
=
q
.
shape
[
-
3
],
k
.
shape
[
-
3
]
no_q
,
no_kv
=
q
.
shape
[
-
2
],
k
.
shape
[
-
2
]
# [*,
Q
,
H
, C_hidden]
# [*,
H
,
Q
, C_hidden]
o
=
q
.
new_zeros
(
q
.
shape
)
o
=
q
.
new_zeros
(
q
.
shape
)
for
q_s
in
range
(
0
,
no_q
,
q_chunk_size
):
for
q_s
in
range
(
0
,
no_q
,
q_chunk_size
):
q_chunk
=
q
[...,
q_s
:
q_s
+
q_chunk_size
,
:,
:]
q_chunk
=
q
[...,
q_s
:
q_s
+
q_chunk_size
,
:]
large_bias_chunks
=
[
large_bias_chunks
=
[
b
[...,
q_s
:
q_s
+
q_chunk_size
,
:]
for
b
in
biases
b
[...,
q_s
:
q_s
+
q_chunk_size
,
:]
for
b
in
biases
]
]
...
@@ -566,24 +620,22 @@ def _lma(
...
@@ -566,24 +620,22 @@ def _lma(
weights
=
[]
weights
=
[]
values
=
[]
values
=
[]
for
kv_s
in
range
(
0
,
no_kv
,
kv_chunk_size
):
for
kv_s
in
range
(
0
,
no_kv
,
kv_chunk_size
):
k_chunk
=
k
[...,
kv_s
:
kv_s
+
kv_chunk_size
,
:,
:]
k_chunk
=
k
[...,
kv_s
:
kv_s
+
kv_chunk_size
,
:]
v_chunk
=
v
[...,
kv_s
:
kv_s
+
kv_chunk_size
,
:,
:]
v_chunk
=
v
[...,
kv_s
:
kv_s
+
kv_chunk_size
,
:]
small_bias_chunks
=
[
small_bias_chunks
=
[
b
[...,
kv_s
:
kv_s
+
kv_chunk_size
]
for
b
in
large_bias_chunks
b
[...,
kv_s
:
kv_s
+
kv_chunk_size
]
for
b
in
large_bias_chunks
]
]
a
=
torch
.
einsum
(
a
=
torch
.
einsum
(
"...
q
hd,...
k
hd->...hqk"
,
q_chunk
,
k_chunk
,
"...h
q
d,...h
k
d->...hqk"
,
q_chunk
,
k_chunk
,
)
)
for
b
in
small_bias_chunks
:
for
b
in
small_bias_chunks
:
a
+=
b
a
+=
b
a
=
a
.
transpose
(
-
2
,
-
3
)
max_a
=
torch
.
max
(
a
,
dim
=-
1
,
keepdim
=
True
)[
0
]
max_a
=
torch
.
max
(
a
,
dim
=-
1
,
keepdim
=
True
)[
0
]
exp_a
=
torch
.
exp
(
a
-
max_a
)
exp_a
=
torch
.
exp
(
a
-
max_a
)
exp_v
=
torch
.
einsum
(
"...
v
hf,...
q
hv->...
q
hf"
,
v_chunk
,
exp_a
)
exp_v
=
torch
.
einsum
(
"...h
v
f,...h
q
v->...h
q
f"
,
v_chunk
,
exp_a
)
maxes
.
append
(
max_a
.
detach
().
squeeze
(
-
1
))
maxes
.
append
(
max_a
.
detach
().
squeeze
(
-
1
))
weights
.
append
(
torch
.
sum
(
exp_a
,
dim
=-
1
))
weights
.
append
(
torch
.
sum
(
exp_a
,
dim
=-
1
))
...
@@ -595,14 +647,80 @@ def _lma(
...
@@ -595,14 +647,80 @@ def _lma(
global_max
=
torch
.
max
(
chunk_max
,
dim
=-
3
,
keepdim
=
True
)[
0
]
global_max
=
torch
.
max
(
chunk_max
,
dim
=-
3
,
keepdim
=
True
)[
0
]
max_diffs
=
torch
.
exp
(
chunk_max
-
global_max
)
max_diffs
=
torch
.
exp
(
chunk_max
-
global_max
)
chunk_values
*
=
max_diffs
.
unsqueeze
(
-
1
)
chunk_values
=
chunk_values
*
max_diffs
.
unsqueeze
(
-
1
)
chunk_weights
*
=
max_diffs
chunk_weights
=
chunk_weights
*
max_diffs
all_values
=
torch
.
sum
(
chunk_values
,
dim
=-
4
)
all_values
=
torch
.
sum
(
chunk_values
,
dim
=-
4
)
all_weights
=
torch
.
sum
(
chunk_weights
.
unsqueeze
(
-
1
),
dim
=-
4
)
all_weights
=
torch
.
sum
(
chunk_weights
.
unsqueeze
(
-
1
),
dim
=-
4
)
q_chunk_out
=
all_values
/
all_weights
q_chunk_out
=
all_values
/
all_weights
o
[...,
q_s
:
q_s
+
q_chunk_size
,
:,
:]
=
q_chunk_out
o
[...,
q_s
:
q_s
+
q_chunk_size
,
:]
=
q_chunk_out
return
o
return
o
@
torch
.
jit
.
ignore
def
_flash_attn
(
q
,
k
,
v
,
kv_mask
):
if
(
not
fa_is_installed
):
raise
ValueError
(
"_flash_attn requires that FlashAttention be installed"
)
batch_dims
=
q
.
shape
[:
-
3
]
no_heads
,
n
,
c
=
q
.
shape
[
-
3
:]
dtype
=
q
.
dtype
q
=
q
.
half
()
k
=
k
.
half
()
v
=
v
.
half
()
kv_mask
=
kv_mask
.
half
()
# [*, B, N, H, C]
q
=
q
.
transpose
(
-
2
,
-
3
)
k
=
k
.
transpose
(
-
2
,
-
3
)
v
=
v
.
transpose
(
-
2
,
-
3
)
# [B_flat, N, H, C]
q
=
q
.
reshape
(
-
1
,
*
q
.
shape
[
-
3
:])
k
=
k
.
reshape
(
-
1
,
*
k
.
shape
[
-
3
:])
v
=
v
.
reshape
(
-
1
,
*
v
.
shape
[
-
3
:])
# Flattened batch size
batch_size
=
q
.
shape
[
0
]
# [B_flat * N, H, C]
q
=
q
.
reshape
(
-
1
,
*
q
.
shape
[
-
2
:])
q_max_s
=
n
q_cu_seqlens
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
n
,
step
=
n
,
dtype
=
torch
.
int32
,
device
=
q
.
device
)
# [B_flat, N, 2, H, C]
kv
=
torch
.
stack
([
k
,
v
],
dim
=-
3
)
kv_shape
=
kv
.
shape
# [B_flat, N, 2 * H * C]
kv
=
kv
.
reshape
(
*
kv
.
shape
[:
-
3
],
-
1
)
kv_unpad
,
_
,
kv_cu_seqlens
,
kv_max_s
=
unpad_input
(
kv
,
kv_mask
)
kv_unpad
=
kv_unpad
.
reshape
(
-
1
,
*
kv_shape
[
-
3
:])
out
=
flash_attn_unpadded_kvpacked_func
(
q
,
kv_unpad
,
q_cu_seqlens
,
kv_cu_seqlens
,
q_max_s
,
kv_max_s
,
dropout_p
=
0.
,
softmax_scale
=
1.
,
# q has been scaled already
)
# [*, B, N, H, C]
out
=
out
.
reshape
(
*
batch_dims
,
n
,
no_heads
,
c
)
out
=
out
.
to
(
dtype
=
dtype
)
return
out
openfold/model/structure_module.py
View file @
39a6d0e6
...
@@ -12,11 +12,15 @@
...
@@ -12,11 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
reduce
import
importlib
import
math
import
math
import
sys
from
operator
import
mul
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Sequence
,
Union
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
ipa_point_weights_init_
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
ipa_point_weights_init_
from
openfold.np.residue_constants
import
(
from
openfold.np.residue_constants
import
(
...
@@ -27,11 +31,12 @@ from openfold.np.residue_constants import (
...
@@ -27,11 +31,12 @@ from openfold.np.residue_constants import (
)
)
from
openfold.utils.geometry.quat_rigid
import
QuatRigid
from
openfold.utils.geometry.quat_rigid
import
QuatRigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.vector
import
Vec3Array
from
openfold.utils.geometry.vector
import
Vec3Array
,
square_euclidean_distance
from
openfold.utils.feats
import
(
from
openfold.utils.feats
import
(
frames_and_literature_positions_to_atom14_pos
,
frames_and_literature_positions_to_atom14_pos
,
torsion_angles_to_frames
,
torsion_angles_to_frames
,
)
)
from
openfold.utils.precision_utils
import
is_fp16_enabled
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
dict_multimap
,
dict_multimap
,
...
@@ -39,6 +44,8 @@ from openfold.utils.tensor_utils import (
...
@@ -39,6 +44,8 @@ from openfold.utils.tensor_utils import (
flatten_final_dims
,
flatten_final_dims
,
)
)
attn_core_inplace_cuda
=
importlib
.
import_module
(
"attn_core_inplace_cuda"
)
class
AngleResnetBlock
(
nn
.
Module
):
class
AngleResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
c_hidden
):
def
__init__
(
self
,
c_hidden
):
...
@@ -164,6 +171,7 @@ class PointProjection(nn.Module):
...
@@ -164,6 +171,7 @@ class PointProjection(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
return_local_points
=
return_local_points
self
.
return_local_points
=
return_local_points
self
.
no_heads
=
no_heads
self
.
no_heads
=
no_heads
self
.
num_points
=
num_points
self
.
linear
=
Linear
(
c_hidden
,
no_heads
*
3
*
num_points
)
self
.
linear
=
Linear
(
c_hidden
,
no_heads
*
3
*
num_points
)
...
@@ -173,22 +181,30 @@ class PointProjection(nn.Module):
...
@@ -173,22 +181,30 @@ class PointProjection(nn.Module):
)
->
Union
[
Vec3Array
,
Tuple
[
Vec3Array
,
Vec3Array
],
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
)
->
Union
[
Vec3Array
,
Tuple
[
Vec3Array
,
Vec3Array
],
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
# TODO: Needs to run in high precision during training
# TODO: Needs to run in high precision during training
points_local
=
self
.
linear
(
activations
)
points_local
=
self
.
linear
(
activations
)
points_local
=
points_local
.
reshape
(
*
points_local
.
shape
[:
-
1
],
if
isinstance
(
rigids
,
Rigid3Array
):
self
.
no_heads
,
points_local
=
points_local
.
reshape
(
-
1
,
*
points_local
.
shape
[:
-
1
],
)
self
.
no_heads
,
-
1
,
)
points_local
=
torch
.
split
(
points_local
=
torch
.
split
(
points_local
,
points_local
.
shape
[
-
1
]
//
3
,
dim
=-
1
points_local
,
points_local
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
)
points_local
=
torch
.
stack
(
points_local
,
dim
=-
1
)
points_local
=
torch
.
stack
(
points_local
,
dim
=-
1
)
if
not
isinstance
(
rigids
,
Rigid3Array
):
points_local
=
points_local
.
reshape
(
*
points_local
.
shape
[:
-
2
],
self
.
no_heads
,
-
1
,
3
)
points_global
=
rigids
[...,
None
,
None
].
apply
(
points_local
)
points_global
=
rigids
[...,
None
,
None
].
apply
(
points_local
)
if
(
self
.
return_local_points
):
if
(
self
.
return_local_points
):
return
points_global
,
points_local
return
points_global
,
points_local
return
points_global
return
points_global
class
InvariantPointAttention
(
nn
.
Module
):
class
InvariantPointAttention
(
nn
.
Module
):
...
@@ -242,8 +258,8 @@ class InvariantPointAttention(nn.Module):
...
@@ -242,8 +258,8 @@ class InvariantPointAttention(nn.Module):
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
,
bias
=
(
not
is_multimer
))
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
,
bias
=
(
not
is_multimer
))
self
.
linear_q_points
=
PointProjection
(
self
.
linear_q_points
=
PointProjection
(
self
.
c_s
,
self
.
c_s
,
self
.
no_qk_points
,
self
.
no_qk_points
,
self
.
no_heads
self
.
no_heads
)
)
...
@@ -288,6 +304,9 @@ class InvariantPointAttention(nn.Module):
...
@@ -288,6 +304,9 @@ class InvariantPointAttention(nn.Module):
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
r
:
Union
[
Rigid
,
Rigid3Array
],
r
:
Union
[
Rigid
,
Rigid3Array
],
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_z_reference_list
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -302,6 +321,11 @@ class InvariantPointAttention(nn.Module):
...
@@ -302,6 +321,11 @@ class InvariantPointAttention(nn.Module):
Returns:
Returns:
[*, N_res, C_s] single representation update
[*, N_res, C_s] single representation update
"""
"""
if
(
_offload_inference
and
inplace_safe
):
z
=
_z_reference_list
else
:
z
=
[
z
]
#######################################
#######################################
# Generate scalar and point activations
# Generate scalar and point activations
#######################################
#######################################
...
@@ -312,7 +336,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -312,7 +336,7 @@ class InvariantPointAttention(nn.Module):
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, N_res, H, P_qk]
# [*, N_res, H, P_qk]
q_pts
=
self
.
linear_q_points
(
s
,
r
)
q_pts
=
self
.
linear_q_points
(
s
,
r
)
# The following two blocks are equivalent
# The following two blocks are equivalent
# They're separated only to preserve compatibility with old AF weights
# They're separated only to preserve compatibility with old AF weights
...
@@ -351,13 +375,25 @@ class InvariantPointAttention(nn.Module):
...
@@ -351,13 +375,25 @@ class InvariantPointAttention(nn.Module):
# Compute attention scores
# Compute attention scores
##########################
##########################
# [*, N_res, N_res, H]
# [*, N_res, N_res, H]
b
=
self
.
linear_b
(
z
)
b
=
self
.
linear_b
(
z
[
0
])
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
z
[
0
])
==
2
)
z
[
0
]
=
z
[
0
].
cpu
()
# [*, H, N_res, N_res]
# [*, H, N_res, N_res]
a
=
torch
.
matmul
(
if
(
is_fp16_enabled
()):
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
a
=
torch
.
matmul
(
)
permute_final_dims
(
q
.
float
(),
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
permute_final_dims
(
k
.
float
(),
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
)
else
:
a
=
torch
.
matmul
(
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
)
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
...
@@ -369,7 +405,12 @@ class InvariantPointAttention(nn.Module):
...
@@ -369,7 +405,12 @@ class InvariantPointAttention(nn.Module):
pt_att
=
sum
([
c
**
2
for
c
in
pt_att
])
pt_att
=
sum
([
c
**
2
for
c
in
pt_att
])
else
:
else
:
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
pt_att
**
2
if
(
inplace_safe
):
pt_att
*=
pt_att
else
:
pt_att
=
pt_att
**
2
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
...
@@ -378,7 +419,11 @@ class InvariantPointAttention(nn.Module):
...
@@ -378,7 +419,11 @@ class InvariantPointAttention(nn.Module):
head_weights
=
head_weights
*
math
.
sqrt
(
head_weights
=
head_weights
*
math
.
sqrt
(
1.0
/
(
3
*
(
self
.
no_qk_points
*
9.0
/
2
))
1.0
/
(
3
*
(
self
.
no_qk_points
*
9.0
/
2
))
)
)
pt_att
=
pt_att
*
head_weights
if
(
inplace_safe
):
pt_att
*=
head_weights
else
:
pt_att
=
pt_att
*
head_weights
# [*, N_res, N_res, H]
# [*, N_res, N_res, H]
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
*
(
-
0.5
)
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
*
(
-
0.5
)
...
@@ -388,9 +433,21 @@ class InvariantPointAttention(nn.Module):
...
@@ -388,9 +433,21 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
# [*, H, N_res, N_res]
pt_att
=
permute_final_dims
(
pt_att
,
(
2
,
0
,
1
))
pt_att
=
permute_final_dims
(
pt_att
,
(
2
,
0
,
1
))
a
=
a
+
pt_att
a
=
a
+
square_mask
.
unsqueeze
(
-
3
)
if
(
inplace_safe
):
a
=
self
.
softmax
(
a
)
a
+=
pt_att
del
pt_att
a
+=
square_mask
.
unsqueeze
(
-
3
)
# in-place softmax
attn_core_inplace_cuda
.
forward_
(
a
,
reduce
(
mul
,
a
.
shape
[:
-
1
]),
a
.
shape
[
-
1
],
)
else
:
a
=
a
+
pt_att
a
=
a
+
square_mask
.
unsqueeze
(
-
3
)
a
=
self
.
softmax
(
a
)
################
################
# Compute output
# Compute output
...
@@ -419,13 +476,22 @@ class InvariantPointAttention(nn.Module):
...
@@ -419,13 +476,22 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v]
# [*, N_res, H * P_v]
o_pt_norm
=
o_pt
.
norm
(
self
.
eps
)
o_pt_norm
=
o_pt
.
norm
(
self
.
eps
)
else
:
else
:
o_pt
=
torch
.
sum
(
# [*, H, 3, N_res, P_v]
(
if
(
inplace_safe
):
a
[...,
None
,
:,
:,
None
]
v_pts
=
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
o_pt
=
[
),
torch
.
matmul
(
a
,
v
.
to
(
a
.
dtype
))
dim
=-
2
,
for
v
in
torch
.
unbind
(
v_pts
,
dim
=-
3
)
)
]
o_pt
=
torch
.
stack
(
o_pt
,
dim
=-
3
)
else
:
o_pt
=
torch
.
sum
(
(
a
[...,
None
,
:,
:,
None
]
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
),
dim
=-
2
,
)
# [*, N_res, H, P_v, 3]
# [*, N_res, H, P_v, 3]
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
...
@@ -440,8 +506,11 @@ class InvariantPointAttention(nn.Module):
...
@@ -440,8 +506,11 @@ class InvariantPointAttention(nn.Module):
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
torch
.
unbind
(
o_pt
,
dim
=-
1
)
o_pt
=
torch
.
unbind
(
o_pt
,
dim
=-
1
)
if
(
_offload_inference
):
z
[
0
]
=
z
[
0
].
to
(
o_pt
.
device
)
# [*, N_res, H, C_z]
# [*, N_res, H, C_z]
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
.
to
(
dtype
=
a
.
dtype
))
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
[
0
]
.
to
(
dtype
=
a
.
dtype
))
# [*, N_res, H * C_z]
# [*, N_res, H * C_z]
o_pair
=
flatten_final_dims
(
o_pair
,
2
)
o_pair
=
flatten_final_dims
(
o_pair
,
2
)
...
@@ -450,7 +519,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -450,7 +519,7 @@ class InvariantPointAttention(nn.Module):
s
=
self
.
linear_out
(
s
=
self
.
linear_out
(
torch
.
cat
(
torch
.
cat
(
(
o
,
*
o_pt
,
o_pt_norm
,
o_pair
),
dim
=-
1
(
o
,
*
o_pt
,
o_pt_norm
,
o_pair
),
dim
=-
1
).
to
(
dtype
=
z
.
dtype
)
).
to
(
dtype
=
z
[
0
]
.
dtype
)
)
)
return
s
return
s
...
@@ -611,11 +680,11 @@ class StructureModule(nn.Module):
...
@@ -611,11 +680,11 @@ class StructureModule(nn.Module):
self
.
inf
=
inf
self
.
inf
=
inf
self
.
is_multimer
=
is_multimer
self
.
is_multimer
=
is_multimer
#
T
o be lazily initialized later
#
Buffers t
o be lazily initialized later
self
.
default_frames
=
None
#
self.default_frames
self
.
group_idx
=
None
#
self.group_idx
self
.
atom_mask
=
None
#
self.atom_mask
self
.
lit_positions
=
None
#
self.lit_positions
self
.
layer_norm_s
=
LayerNorm
(
self
.
c_s
)
self
.
layer_norm_s
=
LayerNorm
(
self
.
c_s
)
self
.
layer_norm_z
=
LayerNorm
(
self
.
c_z
)
self
.
layer_norm_z
=
LayerNorm
(
self
.
c_z
)
...
@@ -655,62 +724,32 @@ class StructureModule(nn.Module):
...
@@ -655,62 +724,32 @@ class StructureModule(nn.Module):
self
.
no_angles
,
self
.
no_angles
,
self
.
epsilon
,
self
.
epsilon
,
)
)
def
_init_residue_constants
(
self
,
float_dtype
,
device
):
if
self
.
default_frames
is
None
:
self
.
default_frames
=
torch
.
tensor
(
restype_rigid_group_default_frame
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
group_idx
is
None
:
self
.
group_idx
=
torch
.
tensor
(
restype_atom14_to_rigid_group
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
atom_mask
is
None
:
self
.
atom_mask
=
torch
.
tensor
(
restype_atom14_mask
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
lit_positions
is
None
:
self
.
lit_positions
=
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
def
torsion_angles_to_frames
(
self
,
r
,
alpha
,
f
):
def
_forward_monomer
(
# Lazily initialize the residue constants on the correct device
self
,
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
evoformer_output_dict
,
# Separated purely to make testing less annoying
return
torsion_angles_to_frames
(
r
,
alpha
,
f
,
self
.
default_frames
)
def
frames_and_literature_positions_to_atom14_pos
(
self
,
r
,
f
# [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
r
.
dtype
,
r
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
r
,
f
,
self
.
default_frames
,
self
.
group_idx
,
self
.
atom_mask
,
self
.
lit_positions
,
)
def
_forward_monomer
(
self
,
s
,
z
,
aatype
,
aatype
,
mask
=
None
,
mask
=
None
,
inplace_safe
=
False
,
_offload_inference
=
False
,
):
):
"""
Args:
evoformer_output_dict:
Dictionary containing:
"single":
[*, N_res, C_s] single representation
"pair":
[*, N_res, N_res, C_z] pair representation
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
"""
s
=
evoformer_output_dict
[
"single"
]
if
mask
is
None
:
if
mask
is
None
:
# [*, N]
# [*, N]
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
...
@@ -719,7 +758,14 @@ class StructureModule(nn.Module):
...
@@ -719,7 +758,14 @@ class StructureModule(nn.Module):
s
=
self
.
layer_norm_s
(
s
)
s
=
self
.
layer_norm_s
(
s
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
self
.
layer_norm_z
(
z
)
z
=
self
.
layer_norm_z
(
evoformer_output_dict
[
"pair"
])
z_reference_list
=
None
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
evoformer_output_dict
[
"pair"
])
==
2
)
evoformer_output_dict
[
"pair"
]
=
evoformer_output_dict
[
"pair"
].
cpu
()
z_reference_list
=
[
z
]
z
=
None
# [*, N, C_s]
# [*, N, C_s]
s_initial
=
s
s_initial
=
s
...
@@ -736,11 +782,19 @@ class StructureModule(nn.Module):
...
@@ -736,11 +782,19 @@ class StructureModule(nn.Module):
outputs
=
[]
outputs
=
[]
for
i
in
range
(
self
.
no_blocks
):
for
i
in
range
(
self
.
no_blocks
):
# [*, N, C_s]
# [*, N, C_s]
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
)
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
,
_z_reference_list
=
z_reference_list
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
transition
(
s
)
s
=
self
.
transition
(
s
)
# [*, N]
# [*, N]
rigids
=
rigids
.
compose_q_update_vec
(
self
.
bb_update
(
s
))
rigids
=
rigids
.
compose_q_update_vec
(
self
.
bb_update
(
s
))
...
@@ -781,24 +835,35 @@ class StructureModule(nn.Module):
...
@@ -781,24 +835,35 @@ class StructureModule(nn.Module):
"unnormalized_angles"
:
unnormalized_angles
,
"unnormalized_angles"
:
unnormalized_angles
,
"angles"
:
angles
,
"angles"
:
angles
,
"positions"
:
pred_xyz
,
"positions"
:
pred_xyz
,
"states"
:
s
,
}
}
outputs
.
append
(
preds
)
outputs
.
append
(
preds
)
if
i
<
(
self
.
no_blocks
-
1
):
rigids
=
rigids
.
stop_rot_gradient
()
rigids
=
rigids
.
stop_rot_gradient
()
del
z
,
z_reference_list
if
(
_offload_inference
):
evoformer_output_dict
[
"pair"
]
=
(
evoformer_output_dict
[
"pair"
].
to
(
s
.
device
)
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
[
"single"
]
=
s
outputs
[
"single"
]
=
s
return
outputs
return
outputs
def
_forward_multimer
(
self
,
def
_forward_multimer
(
s
,
self
,
z
,
evoformer_output_dict
,
aatype
,
aatype
,
mask
=
None
,
mask
=
None
,
inplace_safe
=
False
,
_offload_inference
=
False
,
):
):
s
=
evoformer_output_dict
[
"single"
]
if
mask
is
None
:
if
mask
is
None
:
# [*, N]
# [*, N]
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
...
@@ -807,7 +872,14 @@ class StructureModule(nn.Module):
...
@@ -807,7 +872,14 @@ class StructureModule(nn.Module):
s
=
self
.
layer_norm_s
(
s
)
s
=
self
.
layer_norm_s
(
s
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
self
.
layer_norm_z
(
z
)
z
=
self
.
layer_norm_z
(
evoformer_output_dict
[
"pair"
])
z_reference_list
=
None
if
(
_offload_inference
):
assert
(
sys
.
getrefcount
(
evoformer_output_dict
[
"pair"
])
==
2
)
evoformer_output_dict
[
"pair"
]
=
evoformer_output_dict
[
"pair"
].
cpu
()
z_reference_list
=
[
z
]
z
=
None
# [*, N, C_s]
# [*, N, C_s]
s_initial
=
s
s_initial
=
s
...
@@ -821,7 +893,15 @@ class StructureModule(nn.Module):
...
@@ -821,7 +893,15 @@ class StructureModule(nn.Module):
outputs
=
[]
outputs
=
[]
for
i
in
range
(
self
.
no_blocks
):
for
i
in
range
(
self
.
no_blocks
):
# [*, N, C_s]
# [*, N, C_s]
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
)
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
,
_z_reference_list
=
z_reference_list
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
ipa_dropout
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
layer_norm_ipa
(
s
)
s
=
self
.
transition
(
s
)
s
=
self
.
transition
(
s
)
...
@@ -848,13 +928,19 @@ class StructureModule(nn.Module):
...
@@ -848,13 +928,19 @@ class StructureModule(nn.Module):
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"unnormalized_angles"
:
unnormalized_angles
,
"unnormalized_angles"
:
unnormalized_angles
,
"angles"
:
angles
,
"angles"
:
angles
,
"positions"
:
pred_xyz
.
to_tensor
()
,
"positions"
:
pred_xyz
,
}
}
outputs
.
append
(
preds
)
outputs
.
append
(
preds
)
if
i
<
(
self
.
no_blocks
-
1
):
rigids
=
rigids
.
stop_rot_gradient
()
rigids
=
rigids
.
stop_rot_gradient
()
del
z
,
z_reference_list
if
(
_offload_inference
):
evoformer_output_dict
[
"pair"
]
=
(
evoformer_output_dict
[
"pair"
].
to
(
s
.
device
)
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
[
"single"
]
=
s
outputs
[
"single"
]
=
s
...
@@ -863,10 +949,11 @@ class StructureModule(nn.Module):
...
@@ -863,10 +949,11 @@ class StructureModule(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
s
,
evoformer_output_dict
,
z
,
aatype
,
aatype
,
mask
=
None
,
mask
=
None
,
inplace_safe
=
False
,
_offload_inference
=
False
,
):
):
"""
"""
Args:
Args:
...
@@ -882,8 +969,73 @@ class StructureModule(nn.Module):
...
@@ -882,8 +969,73 @@ class StructureModule(nn.Module):
A dictionary of outputs
A dictionary of outputs
"""
"""
if
(
self
.
is_multimer
):
if
(
self
.
is_multimer
):
outputs
=
self
.
_forward_multimer
(
s
,
z
,
aatype
,
mask
)
outputs
=
self
.
_forward_multimer
(
evoformer_output_dict
,
aatype
,
mask
,
inplace_safe
,
_offload_inference
)
else
:
else
:
outputs
=
self
.
_forward_monomer
(
s
,
z
,
aatype
,
mask
)
outputs
=
self
.
_forward_monomer
(
evoformer_output_dict
,
aatype
,
mask
,
inplace_safe
,
_offload_inference
)
return
outputs
return
outputs
def
_init_residue_constants
(
self
,
float_dtype
,
device
):
if
not
hasattr
(
self
,
"default_frames"
):
self
.
register_buffer
(
"default_frames"
,
torch
.
tensor
(
restype_rigid_group_default_frame
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
),
persistent
=
False
,
)
if
not
hasattr
(
self
,
"group_idx"
):
self
.
register_buffer
(
"group_idx"
,
torch
.
tensor
(
restype_atom14_to_rigid_group
,
device
=
device
,
requires_grad
=
False
,
),
persistent
=
False
,
)
if
not
hasattr
(
self
,
"atom_mask"
):
self
.
register_buffer
(
"atom_mask"
,
torch
.
tensor
(
restype_atom14_mask
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
),
persistent
=
False
,
)
if
not
hasattr
(
self
,
"lit_positions"
):
self
.
register_buffer
(
"lit_positions"
,
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
),
persistent
=
False
,
)
def
torsion_angles_to_frames
(
self
,
r
,
alpha
,
f
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
# Separated purely to make testing less annoying
return
torsion_angles_to_frames
(
r
,
alpha
,
f
,
self
.
default_frames
)
def
frames_and_literature_positions_to_atom14_pos
(
self
,
r
,
f
# [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
r
.
dtype
,
r
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
r
,
f
,
self
.
default_frames
,
self
.
group_idx
,
self
.
atom_mask
,
self
.
lit_positions
,
)
openfold/model/template.py
View file @
39a6d0e6
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# limitations under the License.
# limitations under the License.
from
functools
import
partial
from
functools
import
partial
import
math
import
math
import
sys
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
import
torch
import
torch
...
@@ -34,10 +35,19 @@ from openfold.model.triangular_multiplicative_update import (
...
@@ -34,10 +35,19 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming
,
TriangleMultiplicationIncoming
,
)
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
from
openfold.utils.checkpointing
import
checkpoint_blocks
from
openfold.utils.
tensor
_utils
import
(
from
openfold.utils.
chunk
_utils
import
(
chunk_layer
,
chunk_layer
,
ChunkSizeTuner
,
)
from
openfold.utils.feats
import
(
build_template_angle_feat
,
build_template_pair_feat
,
)
from
openfold.utils.tensor_utils
import
(
add
,
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
tensor_tree_map
,
)
)
...
@@ -77,6 +87,7 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -77,6 +87,7 @@ class TemplatePointwiseAttention(nn.Module):
t
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
biases
:
List
[
torch
.
Tensor
],
biases
:
List
[
torch
.
Tensor
],
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
mha_inputs
=
{
mha_inputs
=
{
"q_x"
:
z
,
"q_x"
:
z
,
...
@@ -84,7 +95,7 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -84,7 +95,7 @@ class TemplatePointwiseAttention(nn.Module):
"biases"
:
biases
,
"biases"
:
biases
,
}
}
return
chunk_layer
(
return
chunk_layer
(
self
.
mha
,
partial
(
self
.
mha
,
use_lma
=
use_lma
),
mha_inputs
,
mha_inputs
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
...
@@ -95,7 +106,9 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -95,7 +106,9 @@ class TemplatePointwiseAttention(nn.Module):
t
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
template_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
template_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
# This module suffers greatly from a small chunk size
chunk_size
:
Optional
[
int
]
=
256
,
use_lma
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -121,10 +134,10 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -121,10 +134,10 @@ class TemplatePointwiseAttention(nn.Module):
# [*, N_res, N_res, 1, C_z]
# [*, N_res, N_res, 1, C_z]
biases
=
[
bias
]
biases
=
[
bias
]
if
chunk_size
is
not
None
:
if
chunk_size
is
not
None
and
not
self
.
training
:
z
=
self
.
_chunk
(
z
,
t
,
biases
,
chunk_size
)
z
=
self
.
_chunk
(
z
,
t
,
biases
,
chunk_size
,
use_lma
=
use_lma
)
else
:
else
:
z
=
self
.
mha
(
q_x
=
z
,
kv_x
=
t
,
biases
=
biases
)
z
=
self
.
mha
(
q_x
=
z
,
kv_x
=
t
,
biases
=
biases
,
use_lma
=
use_lma
)
# [*, N_res, N_res, C_z]
# [*, N_res, N_res, C_z]
z
=
z
.
squeeze
(
-
2
)
z
=
z
.
squeeze
(
-
2
)
...
@@ -186,74 +199,118 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -186,74 +199,118 @@ class TemplatePairStackBlock(nn.Module):
self
.
pair_transition_n
,
self
.
pair_transition_n
,
)
)
def
tri_att_start_end
(
self
,
single
,
single_mask
,
chunk_size
):
def
tri_att_start_end
(
self
,
single
,
_attn_chunk_size
,
single_mask
,
use_lma
,
inplace_safe
):
single
=
single
+
self
.
dropout_row
(
single
=
add
(
single
,
self
.
tri_att_start
(
self
.
dropout_row
(
single
,
self
.
tri_att_start
(
chunk_size
=
chunk_size
,
single
,
mask
=
single_mask
chunk_size
=
_attn_chunk_size
,
)
mask
=
single_mask
,
)
use_lma
=
use_lma
,
single
=
single
+
self
.
dropout_col
(
inplace_safe
=
inplace_safe
,
self
.
tri_att_end
(
)
single
,
),
chunk_size
=
chunk_size
,
inplace_safe
,
mask
=
single_mask
)
)
)
single
=
add
(
single
,
self
.
dropout_col
(
self
.
tri_att_end
(
single
,
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
inplace_safe
,
)
return
single
return
single
def
tri_mul_out_in
(
self
,
single
,
single_mask
):
def
tri_mul_out_in
(
self
,
single
,
single_mask
,
inplace_safe
):
single
=
single
+
self
.
dropout_row
(
tmu_update
=
self
.
tri_mul_out
(
s
elf
.
tri_mul_out
(
s
ingle
,
single
,
mask
=
single
_mask
,
mask
=
single_mask
inplace_safe
=
inplace_safe
,
)
_add_with_inplace
=
True
,
)
)
single
=
single
+
self
.
dropout_row
(
if
(
not
inplace_safe
):
self
.
tri_mul_in
(
single
=
single
+
self
.
dropout_row
(
tmu_update
)
single
,
else
:
mask
=
single_mask
single
=
tmu_update
)
del
tmu_update
tmu_update
=
self
.
tri_mul_in
(
single
,
mask
=
single_mask
,
inplace_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
)
if
(
not
inplace_safe
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
del
tmu_update
return
single
return
single
def
forward
(
self
,
def
forward
(
self
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
_mask_trans
:
bool
=
True
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
):
):
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
single_templates
=
[
single_templates
=
[
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
]
]
single_templates_masks
=
[
single_templates_masks
=
[
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
]
]
for
i
in
range
(
len
(
single_templates
)):
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
single_mask
=
single_templates_masks
[
i
]
if
self
.
tri_mul_first
:
if
self
.
tri_mul_first
:
single
=
self
.
tri_att_start_end
(
single
=
self
.
tri_mul_out_in
(
single
=
single
,
single
=
self
.
tri_att_start_end
(
single
=
self
.
tri_mul_out_in
(
single
=
single
,
single_mask
=
single_mask
),
single_mask
=
single_mask
,
inplace_safe
=
inplace_safe
),
_attn_chunk_size
=
_attn_chunk_size
,
single_mask
=
single_mask
,
single_mask
=
single_mask
,
chunk_size
=
chunk_size
)
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
)
else
:
else
:
single
=
self
.
tri_mul_out_in
(
single
=
self
.
tri_att_start_end
(
single
=
single
,
single
=
self
.
tri_mul_out_in
(
single
=
self
.
tri_att_start_end
(
single
=
single
,
_attn_chunk_size
=
_attn_chunk_size
,
single_mask
=
single_mask
,
single_mask
=
single_mask
,
chunk_size
=
chunk_size
),
use_lma
=
use_lma
,
single_mask
=
single_mask
)
inplace_safe
=
inplace_safe
),
single_mask
=
single_mask
,
single
=
single
+
self
.
pair_transition
(
inplace_safe
=
inplace_safe
)
single
,
mask
=
single_mask
if
_mask_trans
else
None
,
single
=
add
(
single
,
chunk_size
=
chunk_size
,
self
.
pair_transition
(
)
single
,
mask
=
single_mask
if
_mask_trans
else
None
,
single_templates
[
i
]
=
single
chunk_size
=
chunk_size
,
),
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
inplace_safe
,
)
if
(
not
inplace_safe
):
single_templates
[
i
]
=
single
if
(
not
inplace_safe
):
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
return
z
return
z
...
@@ -273,6 +330,7 @@ class TemplatePairStack(nn.Module):
...
@@ -273,6 +330,7 @@ class TemplatePairStack(nn.Module):
dropout_rate
,
dropout_rate
,
tri_mul_first
,
tri_mul_first
,
blocks_per_ckpt
,
blocks_per_ckpt
,
tune_chunk_size
:
bool
=
False
,
inf
=
1e9
,
inf
=
1e9
,
**
kwargs
,
**
kwargs
,
):
):
...
@@ -314,11 +372,18 @@ class TemplatePairStack(nn.Module):
...
@@ -314,11 +372,18 @@ class TemplatePairStack(nn.Module):
self
.
layer_norm
=
LayerNorm
(
c_t
)
self
.
layer_norm
=
LayerNorm
(
c_t
)
self
.
tune_chunk_size
=
tune_chunk_size
self
.
chunk_size_tuner
=
None
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
forward
(
def
forward
(
self
,
self
,
t
:
torch
.
tensor
,
t
:
torch
.
tensor
,
mask
:
torch
.
tensor
,
mask
:
torch
.
tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
):
):
"""
"""
...
@@ -335,16 +400,34 @@ class TemplatePairStack(nn.Module):
...
@@ -335,16 +400,34 @@ class TemplatePairStack(nn.Module):
expand_idx
[
-
3
]
=
t
.
shape
[
-
4
]
expand_idx
[
-
3
]
=
t
.
shape
[
-
4
]
mask
=
mask
.
expand
(
*
expand_idx
)
mask
=
mask
.
expand
(
*
expand_idx
)
blocks
=
[
partial
(
b
,
mask
=
mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
t
.
clone
(),),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
partial
(
b
,
chunk_size
=
tuned_chunk_size
,
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
4
),
)
for
b
in
blocks
]
t
,
=
checkpoint_blocks
(
t
,
=
checkpoint_blocks
(
blocks
=
[
blocks
=
blocks
,
partial
(
b
,
mask
=
mask
,
chunk_size
=
chunk_size
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
],
args
=
(
t
,),
args
=
(
t
,),
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
)
)
...
@@ -352,3 +435,223 @@ class TemplatePairStack(nn.Module):
...
@@ -352,3 +435,223 @@ class TemplatePairStack(nn.Module):
t
=
self
.
layer_norm
(
t
)
t
=
self
.
layer_norm
(
t
)
return
t
return
t
def
embed_templates_offload
(
model
,
batch
,
z
,
pair_mask
,
templ_dim
,
template_chunk_size
=
256
,
inplace_safe
=
False
,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
template_chunk_size:
Integer value controlling how quickly the offloaded pair embedding
tensor is brought back into GPU memory. In dire straits, can be
lowered to reduce memory consumption of this function even more.
Returns:
A dictionary of template pair and angle embeddings.
A version of the "embed_templates" method of the AlphaFold class that
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu
=
[]
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
for
i
in
range
(
n_templ
):
idx
=
batch
[
"template_aatype"
].
new_tensor
(
i
)
single_template_feats
=
tensor_tree_map
(
lambda
t
:
torch
.
index_select
(
t
,
templ_dim
,
idx
).
squeeze
(
templ_dim
),
batch
,
)
# [*, N, N, C_t]
t
=
build_template_pair_feat
(
single_template_feats
,
use_unit_vector
=
model
.
config
.
template
.
use_unit_vector
,
inf
=
model
.
config
.
template
.
inf
,
eps
=
model
.
config
.
template
.
eps
,
**
model
.
config
.
template
.
distogram
,
).
to
(
z
.
dtype
)
t
=
model
.
template_pair_embedder
(
t
)
# [*, 1, N, N, C_z]
t
=
model
.
template_pair_stack
(
t
.
unsqueeze
(
templ_dim
),
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
model
.
globals
.
chunk_size
,
use_lma
=
model
.
globals
.
use_lma
,
_mask_trans
=
model
.
config
.
_mask_trans
,
)
assert
(
sys
.
getrefcount
(
t
)
==
2
)
pair_embeds_cpu
.
append
(
t
.
cpu
())
del
t
# Preallocate the output tensor
t
=
z
.
new_zeros
(
z
.
shape
)
for
i
in
range
(
0
,
n
,
template_chunk_size
):
pair_chunks
=
[
p
[...,
i
:
i
+
template_chunk_size
,
:,
:]
for
p
in
pair_embeds_cpu
]
pair_chunk
=
torch
.
cat
(
pair_chunks
,
dim
=
templ_dim
).
to
(
device
=
z
.
device
)
z_chunk
=
z
[...,
i
:
i
+
template_chunk_size
,
:,
:]
att_chunk
=
model
.
template_pointwise_att
(
pair_chunk
,
z_chunk
,
template_mask
=
batch
[
"template_mask"
].
to
(
dtype
=
z
.
dtype
),
use_lma
=
model
.
globals
.
use_lma
,
)
t
[...,
i
:
i
+
template_chunk_size
,
:,
:]
=
att_chunk
del
pair_chunks
if
(
inplace_safe
):
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
t
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
if
model
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
,
)
# [*, N, C_m]
a
=
model
.
template_angle_embedder
(
template_angle_feat
)
ret
[
"template_single_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
return
ret
def
embed_templates_average
(
model
,
batch
,
z
,
pair_mask
,
templ_dim
,
templ_group_size
=
2
,
inplace_safe
=
False
,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
templ_group_size:
Granularity of the approximation. Larger values trade memory for
greater proximity to the original function
Returns:
A dictionary of template pair and angle embeddings.
A memory-efficient approximation of the "embed_templates" method of the
AlphaFold class. Instead of running pointwise attention over pair
embeddings for all of the templates at the same time, it splits templates
into groups of size templ_group_size, computes embeddings for each group
normally, and then averages the group embeddings. In our experiments, this
approximation has a minimal effect on the quality of the resulting
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
# Embed the templates one at a time (with a poor man's vmap)
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
out_tensor
=
z
.
new_zeros
(
z
.
shape
)
for
i
in
range
(
0
,
n_templ
,
templ_group_size
):
def
slice_template_tensor
(
t
):
s
=
[
slice
(
None
)
for
_
in
t
.
shape
]
s
[
templ_dim
]
=
slice
(
i
,
i
+
templ_group_size
)
return
t
[
s
]
template_feats
=
tensor_tree_map
(
slice_template_tensor
,
batch
,
)
# [*, N, N, C_t]
t
=
build_template_pair_feat
(
template_feats
,
use_unit_vector
=
model
.
config
.
template
.
use_unit_vector
,
inf
=
model
.
config
.
template
.
inf
,
eps
=
model
.
config
.
template
.
eps
,
**
model
.
config
.
template
.
distogram
,
).
to
(
z
.
dtype
)
# [*, S_t, N, N, C_z]
t
=
model
.
template_pair_embedder
(
t
)
t
=
model
.
template_pair_stack
(
t
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
model
.
globals
.
chunk_size
,
use_lma
=
model
.
globals
.
use_lma
,
_mask_trans
=
model
.
config
.
_mask_trans
,
)
t
=
model
.
template_pointwise_att
(
t
,
z
,
template_mask
=
template_feats
[
"template_mask"
].
to
(
dtype
=
z
.
dtype
),
use_lma
=
model
.
globals
.
use_lma
,
)
denom
=
math
.
ceil
(
n_templ
/
templ_group_size
)
if
(
inplace_safe
):
t
/=
denom
else
:
t
=
t
/
denom
if
(
inplace_safe
):
out_tensor
+=
t
else
:
out_tensor
=
out_tensor
+
t
del
t
if
(
inplace_safe
):
out_tensor
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
out_tensor
=
out_tensor
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
if
model
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
,
)
# [*, N, C_m]
a
=
model
.
template_angle_embedder
(
template_angle_feat
)
ret
[
"template_single_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
out_tensor
})
return
ret
openfold/model/triangular_attention.py
View file @
39a6d0e6
...
@@ -21,8 +21,8 @@ import torch
...
@@ -21,8 +21,8 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
Attention
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
Attention
from
openfold.utils.chunk_utils
import
chunk_layer
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
chunk_layer
,
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
)
)
...
@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
...
@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
class
TriangleAttention
(
nn
.
Module
):
class
TriangleAttention
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
c_in
,
c_hidden
,
no_heads
,
starting
,
inf
=
1e9
self
,
c_in
,
c_hidden
,
no_heads
,
starting
=
True
,
inf
=
1e9
):
):
"""
"""
Args:
Args:
...
@@ -62,23 +62,36 @@ class TriangleAttention(nn.Module):
...
@@ -62,23 +62,36 @@ class TriangleAttention(nn.Module):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
biases
:
List
[
torch
.
Tensor
],
biases
:
List
[
torch
.
Tensor
],
chunk_size
:
int
,
chunk_size
:
int
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"triangle! triangle!"
mha_inputs
=
{
mha_inputs
=
{
"q_x"
:
x
,
"q_x"
:
x
,
"kv_x"
:
x
,
"kv_x"
:
x
,
"biases"
:
biases
,
"biases"
:
biases
,
}
}
return
chunk_layer
(
return
chunk_layer
(
partial
(
self
.
mha
),
partial
(
self
.
mha
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
),
mha_inputs
,
mha_inputs
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
no_batch_dims
=
len
(
x
.
shape
[:
-
2
]),
no_batch_dims
=
len
(
x
.
shape
[:
-
2
]),
_out
=
x
if
inplace_safe
else
None
,
)
)
def
forward
(
self
,
def
forward
(
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -86,15 +99,14 @@ class TriangleAttention(nn.Module):
...
@@ -86,15 +99,14 @@ class TriangleAttention(nn.Module):
[*, I, J, C_in] input tensor (e.g. the pair representation)
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
Returns:
[*, I, J, C_in] output tensor
[*, I, J, C_in] output tensor
"""
"""
if
mask
is
None
:
if
mask
is
None
:
# [*, I, J]
# [*, I, J]
mask
=
x
.
new_ones
(
mask
=
x
.
new_ones
(
x
.
shape
[:
-
1
],
x
.
shape
[:
-
1
],
)
)
# Shape annotations assume self.starting. Else, I and J are flipped
if
(
not
self
.
starting
):
if
not
self
.
starting
:
x
=
x
.
transpose
(
-
2
,
-
3
)
x
=
x
.
transpose
(
-
2
,
-
3
)
mask
=
mask
.
transpose
(
-
1
,
-
2
)
mask
=
mask
.
transpose
(
-
1
,
-
2
)
...
@@ -113,27 +125,35 @@ class TriangleAttention(nn.Module):
...
@@ -113,27 +125,35 @@ class TriangleAttention(nn.Module):
biases
=
[
mask_bias
,
triangle_bias
]
biases
=
[
mask_bias
,
triangle_bias
]
if
chunk_size
is
not
None
:
if
chunk_size
is
not
None
:
x
=
self
.
_chunk
(
x
,
biases
,
chunk_size
)
x
=
self
.
_chunk
(
x
,
biases
,
chunk_size
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
else
:
else
:
x
=
self
.
mha
(
q_x
=
x
,
kv_x
=
x
,
biases
=
biases
)
x
=
self
.
mha
(
q_x
=
x
,
kv_x
=
x
,
biases
=
biases
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
)
if
not
self
.
starting
:
if
(
not
self
.
starting
)
:
x
=
x
.
transpose
(
-
2
,
-
3
)
x
=
x
.
transpose
(
-
2
,
-
3
)
return
x
return
x
class
TriangleAttentionStartingNode
(
TriangleAttention
):
# Implements Algorithm 13
"""
TriangleAttentionStartingNode
=
TriangleAttention
Implements Algorithm 13.
"""
__init__
=
partialmethod
(
TriangleAttention
.
__init__
,
starting
=
True
)
class
TriangleAttentionEndingNode
(
TriangleAttention
):
class
TriangleAttentionEndingNode
(
TriangleAttention
):
"""
"""
Implements Algorithm 14.
Implements Algorithm 14.
"""
"""
__init__
=
partialmethod
(
TriangleAttention
.
__init__
,
starting
=
False
)
__init__
=
partialmethod
(
TriangleAttention
.
__init__
,
starting
=
False
)
openfold/model/triangular_multiplicative_update.py
View file @
39a6d0e6
...
@@ -20,7 +20,9 @@ import torch
...
@@ -20,7 +20,9 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.utils.tensor_utils
import
permute_final_dims
from
openfold.utils.chunk_utils
import
chunk_layer
from
openfold.utils.precision_utils
import
is_fp16_enabled
from
openfold.utils.tensor_utils
import
add
,
permute_final_dims
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
...
@@ -55,12 +57,310 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -55,12 +57,310 @@ class TriangleMultiplicativeUpdate(nn.Module):
def
_combine_projections
(
self
,
def
_combine_projections
(
self
,
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
_inplace_chunk_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
(
"This method needs to be overridden"
)
if
(
self
.
_outgoing
):
a
=
permute_final_dims
(
a
,
(
2
,
0
,
1
))
b
=
permute_final_dims
(
b
,
(
2
,
1
,
0
))
else
:
a
=
permute_final_dims
(
a
,
(
2
,
1
,
0
))
b
=
permute_final_dims
(
b
,
(
2
,
0
,
1
))
if
(
_inplace_chunk_size
is
not
None
):
# To be replaced by torch vmap
for
i
in
range
(
0
,
a
.
shape
[
-
3
],
_inplace_chunk_size
):
a_chunk
=
a
[...,
i
:
i
+
_inplace_chunk_size
,
:,
:]
b_chunk
=
b
[...,
i
:
i
+
_inplace_chunk_size
,
:,
:]
a
[...,
i
:
i
+
_inplace_chunk_size
,
:,
:]
=
(
torch
.
matmul
(
a_chunk
,
b_chunk
,
)
)
p
=
a
else
:
p
=
torch
.
matmul
(
a
,
b
)
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
def
_inference_forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_chunk_size
:
Optional
[
int
]
=
None
,
with_add
:
bool
=
True
,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
inplace_chunk_size:
Size of chunks used in the main computation. Increase to trade
memory for speed.
with_add:
If True, z is overwritten with (z + update). Otherwise, it is
overwritten with (update).
Returns:
A reference to the overwritten z
More memory-efficient, inference-only version of the forward function.
Uses in-place operations, fusion of the addition that happens after
this module in the Evoformer, a smidge of recomputation, and
a cache of overwritten values to lower peak memory consumption of this
module from 5x the size of the input tensor z to 2.5x its size. Useful
for inference on extremely long sequences.
It works as follows. We will make reference to variables used in the
default forward implementation below. Naively, triangle multiplication
attention requires the manifestation of 5 tensors the size of z:
1) z, the "square" input tensor, 2) a, the first projection of z,
3) b, the second projection of b, 4) g, a z-sized mask, and 5) a
z-sized tensor for intermediate computations. For large N, this is
prohibitively expensive; for N=4000, for example, z is more than 8GB
alone. To avoid this problem, we compute b, g, and all intermediate
tensors in small chunks, noting that the chunks required to compute a
chunk of the output depend only on the tensor a and corresponding
vertical and horizontal chunks of z. This suggests an algorithm that
loops over pairs of chunks of z: hereafter "columns" and "rows" of
z, even though each "column" and "row" in fact contains
inplace_chunk_size contiguous true columns and rows of z. Writing
output chunks to a new tensor would bring total memory consumption
down to 3x the size of z. However, more memory can be saved by writing
output chunks directly to z in-place. WLOG, we choose to write output
chunks vertically, overwriting the ith "column" of z at the end of
the ith iteration of the main loop. Despite this overwriting, the
ith column is always one column ahead of previously overwritten columns
and can be recovered directly from z. After the first iteration,
however, the ith row of z is always at least partially overwritten. For
this reason, we introduce the z-cache, a tensor one-half the size of
z. The z-cache initially contains the left half (2nd and 3rd quadrants)
of z. For 0 < i < N/2, the missing left part of the ith row of z is
recovered from this cache at the beginning of the ith iteration. Once i
exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th
quadrants of z instead. Though the 3rd quadrant of the original z is
entirely overwritten at this point, it can be recovered from the z-cache
itself. Thereafter, the ith row of z can be recovered in its entirety
from the reoriented z-cache. After the final iteration, z has been
completely overwritten and contains the triangular multiplicative
update. If with_add is True, it instead contains the sum of z and the
triangular multiplicative update. In either case, peak memory
consumption is just 2.5x the size of z, disregarding memory used for
chunks and other small variables.
"""
if
mask
is
None
:
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
mask
.
unsqueeze
(
-
1
)
def
compute_projection_helper
(
pair
,
mask
,
a
=
True
):
if
(
a
):
linear_g
=
self
.
linear_a_g
linear_p
=
self
.
linear_a_p
else
:
linear_g
=
self
.
linear_b_g
linear_p
=
self
.
linear_b_p
pair
=
self
.
layer_norm_in
(
pair
)
p
=
linear_g
(
pair
)
p
.
sigmoid_
()
p
*=
linear_p
(
pair
)
p
*=
mask
p
=
permute_final_dims
(
p
,
(
2
,
0
,
1
))
return
p
def
compute_projection
(
pair
,
mask
,
a
=
True
,
chunked
=
True
):
need_transpose
=
self
.
_outgoing
^
a
if
(
not
chunked
):
p
=
compute_projection_helper
(
pair
,
mask
,
a
)
if
(
need_transpose
):
p
=
p
.
transpose
(
-
1
,
-
2
)
else
:
# This computation is chunked so as not to exceed our 2.5x
# budget with a large intermediate tensor
linear_g
=
self
.
linear_a_g
if
a
else
self
.
linear_b_g
c
=
linear_g
.
bias
.
shape
[
-
1
]
out_shape
=
pair
.
shape
[:
-
3
]
+
(
c
,)
+
pair
.
shape
[
-
3
:
-
1
]
p
=
pair
.
new_zeros
(
out_shape
)
for
i
in
range
(
0
,
pair
.
shape
[
-
3
],
inplace_chunk_size
):
pair_chunk
=
pair
[...,
i
:
i
+
inplace_chunk_size
,
:,
:]
mask_chunk
=
mask
[...,
i
:
i
+
inplace_chunk_size
,
:,
:]
pair_chunk
=
compute_projection_helper
(
pair
[...,
i
:
i
+
inplace_chunk_size
,
:,
:],
mask
[...,
i
:
i
+
inplace_chunk_size
,
:,
:],
a
,
)
if
(
need_transpose
):
pair_chunk
=
pair_chunk
.
transpose
(
-
1
,
-
2
)
p
[...,
i
:
i
+
inplace_chunk_size
]
=
pair_chunk
else
:
p
[...,
i
:
i
+
inplace_chunk_size
,
:]
=
pair_chunk
del
pair_chunk
return
p
# We start by fully manifesting a. In addition to the input, this
# brings total memory consumption to 2x z (disregarding size of chunks)
# [*, N, N, c]
a
=
compute_projection
(
z
,
mask
,
True
,
chunked
=
True
)
if
(
inplace_chunk_size
is
not
None
):
n
=
a
.
shape
[
-
1
]
half_n
=
n
//
2
+
n
%
2
row_dim
=
-
3
col_dim
=
-
2
b_chunk_dim
=
row_dim
if
self
.
_outgoing
else
col_dim
def
empty_slicer
(
t
):
return
[
slice
(
None
)
for
_
in
t
.
shape
]
def
slice_tensor
(
t
,
start
,
end
,
dim
):
# Slices start:end from the dim dimension of t
s
=
empty_slicer
(
t
)
s
[
dim
]
=
slice
(
start
,
end
)
return
t
[
s
]
def
flip_z_cache_
(
z_cache
,
z
):
# "Reorient" the z_cache (see below), filling it with quadrants
# 3---recovered from the z_cache---and 4---recovered from z---
# of the input tensor z.
quadrant_3
=
slice_tensor
(
z_cache
,
half_n
,
None
,
row_dim
)
z_cache
=
z_cache
.
transpose
(
row_dim
,
col_dim
)
# If n is odd, we need to shrink the z_cache by one row
z_cache
=
z_cache
[...,
:(
n
//
2
),
:,
:]
# Move the 3rd quadrant of z into the
first_half_slicer
=
empty_slicer
(
z_cache
)
first_half_slicer
[
col_dim
]
=
slice
(
0
,
half_n
)
z_cache
[
first_half_slicer
]
=
quadrant_3
# Get the fourth quadrant of z
quadrant_4
=
slice_tensor
(
z
,
half_n
,
None
,
row_dim
)
quadrant_4
=
slice_tensor
(
quadrant_4
,
half_n
,
None
,
col_dim
)
# Insert said quadrant into the rotated z-cache
quadrant_3_slicer
=
empty_slicer
(
z_cache
)
quadrant_3_slicer
[
col_dim
]
=
slice
(
half_n
,
None
)
z_cache
[
quadrant_3_slicer
]
=
quadrant_4
return
z_cache
# Initialize the z cache to the left half of z.
z_cache_shape
=
list
(
z
.
shape
)
z_cache_shape
[
col_dim
]
=
half_n
z_cache
=
z
.
new_zeros
(
z_cache_shape
)
z_cache_slicer
=
empty_slicer
(
z_cache
)
z_cache_slicer
[
col_dim
]
=
slice
(
0
,
half_n
)
z_cache
.
copy_
(
z
[
z_cache_slicer
])
z_cache_rotated
=
False
# We need to reorient the z-cache at the halfway point, and we
# don't want a single chunk to straddle that point. We contract one
# of the chunks in the middle to address that problem.
i_range
=
list
(
range
(
0
,
half_n
,
inplace_chunk_size
))
initial_offsets
=
[
i_2
-
i_1
for
i_1
,
i_2
in
zip
(
i_range
,
i_range
[
1
:]
+
[
half_n
])
]
after_half
=
list
(
range
(
half_n
,
n
,
inplace_chunk_size
))
after_half_offsets
=
[
inplace_chunk_size
for
_
in
after_half
]
combined_range_with_offsets
=
zip
(
i_range
+
after_half
,
initial_offsets
+
after_half_offsets
)
for
i
,
offset
in
combined_range_with_offsets
:
if
(
not
z_cache_rotated
and
i
>=
half_n
):
z_cache
=
flip_z_cache_
(
z_cache
,
z
)
z_cache_rotated
=
True
z_chunk_b
=
slice_tensor
(
z
,
i
,
i
+
offset
,
b_chunk_dim
,
)
mask_chunk
=
slice_tensor
(
mask
,
i
,
i
+
offset
,
b_chunk_dim
,
)
z_chunk_b
=
z_chunk_b
.
clone
()
if
(
b_chunk_dim
==
col_dim
):
z_chunk_b
=
slice_tensor
(
z
,
i
,
i
+
offset
,
col_dim
)
else
:
# b_chunk_dim == row_dim
# In this case, the b-dimension (b_chunk_dim) is partially
# overwritten at the end of each iteration. We need to
# restore the missing component from the z-cache.
if
(
not
z_cache_rotated
):
z_chunk_slicer
=
empty_slicer
(
z_chunk_b
)
z_chunk_slicer
[
col_dim
]
=
slice
(
0
,
half_n
)
z_chunk_b
[
z_chunk_slicer
]
=
slice_tensor
(
z_cache
,
i
,
i
+
offset
,
row_dim
,
)
else
:
z_cache_offset
=
i
-
half_n
z_chunk_b
=
slice_tensor
(
z_cache
,
z_cache_offset
,
z_cache_offset
+
offset
,
row_dim
)
b_chunk
=
compute_projection
(
z_chunk_b
,
mask_chunk
,
a
=
False
,
chunked
=
False
)
del
z_chunk_b
x_chunk
=
torch
.
matmul
(
a
,
b_chunk
,
)
x_chunk
=
permute_final_dims
(
x_chunk
,
(
1
,
2
,
0
))
x_chunk
=
self
.
layer_norm_out
(
x_chunk
)
x_chunk
=
self
.
linear_z
(
x_chunk
)
# The g dimension (col_dim) is parallel to and ahead of the
# overwrites in z. We can extract the g chunk normally.
z_chunk_g
=
slice_tensor
(
z
,
i
,
i
+
offset
,
col_dim
)
g_chunk
=
self
.
linear_g
(
self
.
layer_norm_in
(
z_chunk_g
))
g_chunk
.
sigmoid_
()
del
z_chunk_g
x_chunk
*=
g_chunk
# Write the columns into z in-place
z_slicer
=
empty_slicer
(
z
)
z_slicer
[
col_dim
]
=
slice
(
i
,
i
+
offset
)
if
(
with_add
):
z
[
z_slicer
]
+=
x_chunk
else
:
z
[
z_slicer
]
=
x_chunk
else
:
b
=
compute_projection
(
z
,
mask
,
False
,
False
)
x
=
torch
.
matmul
(
a
,
b
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
g
=
self
.
linear_g
(
z
)
g
.
sigmoid_
()
x
*=
g
if
(
with_add
):
z
+=
x
else
:
z
=
x
return
z
def
forward
(
self
,
def
forward
(
self
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_safe
:
bool
=
False
,
_add_with_inplace
:
bool
=
False
,
_inplace_chunk_size
:
Optional
[
int
]
=
256
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -71,57 +371,52 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -71,57 +371,52 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns:
Returns:
[*, N_res, N_res, C_z] output tensor
[*, N_res, N_res, C_z] output tensor
"""
"""
if
(
inplace_safe
):
x
=
self
.
_inference_forward
(
z
,
mask
,
inplace_chunk_size
=
_inplace_chunk_size
,
with_add
=
_add_with_inplace
,
)
return
x
if
mask
is
None
:
if
mask
is
None
:
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
z
=
self
.
layer_norm_in
(
z
)
z
=
self
.
layer_norm_in
(
z
)
a
=
self
.
linear_a_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_a_g
(
z
))
a
=
mask
a
=
a
*
mask
a
=
a
*
self
.
sigmoid
(
self
.
linear_a_g
(
z
))
b
=
self
.
linear_b_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
a
=
a
*
self
.
linear_a_p
(
z
)
b
=
b
*
mask
b
=
mask
x
=
self
.
_combine_projections
(
a
,
b
)
b
=
b
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
b
=
b
*
self
.
linear_b_p
(
z
)
if
(
is_fp16_enabled
()):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
x
=
self
.
_combine_projections
(
a
.
float
(),
b
.
float
())
else
:
x
=
self
.
_combine_projections
(
a
,
b
)
del
a
,
b
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
x
=
self
.
linear_z
(
x
)
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
z
=
x
*
g
x
=
x
*
g
return
z
return
x
class
TriangleMultiplicationOutgoing
(
TriangleMultiplicativeUpdate
):
class
TriangleMultiplicationOutgoing
(
TriangleMultiplicativeUpdate
):
"""
"""
Implements Algorithm 11.
Implements Algorithm 11.
"""
"""
def
_combine_projections
(
self
,
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
True
)
a
:
torch
.
Tensor
,
# [*, N_i, N_k, C]
b
:
torch
.
Tensor
,
# [*, N_j, N_k, C]
):
# [*, C, N_i, N_j]
p
=
torch
.
matmul
(
permute_final_dims
(
a
,
(
2
,
0
,
1
)),
permute_final_dims
(
b
,
(
2
,
1
,
0
)),
)
# [*, N_i, N_j, C]
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
class
TriangleMultiplicationIncoming
(
TriangleMultiplicativeUpdate
):
class
TriangleMultiplicationIncoming
(
TriangleMultiplicativeUpdate
):
"""
"""
Implements Algorithm 12.
Implements Algorithm 12.
"""
"""
def
_combine_projections
(
self
,
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
False
)
a
:
torch
.
Tensor
,
# [*, N_k, N_i, C]
b
:
torch
.
Tensor
,
# [*, N_k, N_j, C]
):
# [*, C, N_i, N_j]
p
=
torch
.
matmul
(
permute_final_dims
(
a
,
(
2
,
1
,
0
)),
permute_final_dims
(
b
,
(
2
,
0
,
1
)),
)
# [*, N_i, N_j, C]
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
openfold/np/__init__.py
View file @
39a6d0e6
import
os
import
glob
import
importlib
as
importlib
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
__all__
=
[
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)
]
_modules
=
[(
m
,
importlib
.
import_module
(
"."
+
m
,
__name__
))
for
m
in
__all__
]
for
_m
in
_modules
:
globals
()[
_m
[
0
]]
=
_m
[
1
]
# Avoid needlessly cluttering the global namespace
del
_files
,
_m
,
_modules
openfold/np/protein.py
View file @
39a6d0e6
...
@@ -16,8 +16,9 @@
...
@@ -16,8 +16,9 @@
"""Protein data type."""
"""Protein data type."""
import
dataclasses
import
dataclasses
import
io
import
io
from
typing
import
Any
,
Mapping
,
Optional
from
typing
import
Any
,
Sequence
,
Mapping
,
Optional
import
re
import
re
import
string
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
from
Bio.PDB
import
PDBParser
from
Bio.PDB
import
PDBParser
...
@@ -51,16 +52,25 @@ class Protein:
...
@@ -51,16 +52,25 @@ class Protein:
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index
:
np
.
ndarray
# [num_res]
residue_index
:
np
.
ndarray
# [num_res]
# 0-indexed number corresponding to the chain in the protein that this
# residue belongs to
chain_index
:
np
.
ndarray
# [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# representing the displacement of the residue from its ground truth mean
# value.
# value.
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
# Chain indices for multi-chain predictions
chain_index
:
Optional
[
np
.
ndarray
]
=
None
# Optional remark about the protein. Included as a comment in output PDB
# files
remark
:
Optional
[
str
]
=
None
# Templates used to generate this protein (prediction-only)
parents
:
Optional
[
Sequence
[
str
]]
=
None
# Chain corresponding to each parent
parents_chain_index
:
Optional
[
Sequence
[
int
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
(
len
(
np
.
unique
(
self
.
chain_index
))
>
PDB_MAX_CHAINS
):
if
(
len
(
np
.
unique
(
self
.
chain_index
))
>
PDB_MAX_CHAINS
):
raise
ValueError
(
raise
ValueError
(
...
@@ -104,7 +114,6 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
...
@@ -104,7 +114,6 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
if
(
chain_id
is
not
None
and
chain
.
id
!=
chain_id
):
if
(
chain_id
is
not
None
and
chain
.
id
!=
chain_id
):
continue
continue
for
res
in
chain
:
for
res
in
chain
:
if
res
.
id
[
2
]
!=
" "
:
if
res
.
id
[
2
]
!=
" "
:
raise
ValueError
(
raise
ValueError
(
...
@@ -129,17 +138,32 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
...
@@ -129,17 +138,32 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
if
np
.
sum
(
mask
)
<
0.5
:
if
np
.
sum
(
mask
)
<
0.5
:
# If no known atom positions are reported for the residue then skip it.
# If no known atom positions are reported for the residue then skip it.
continue
continue
aatype
.
append
(
restype_idx
)
aatype
.
append
(
restype_idx
)
atom_positions
.
append
(
pos
)
atom_positions
.
append
(
pos
)
atom_mask
.
append
(
mask
)
atom_mask
.
append
(
mask
)
residue_index
.
append
(
mask
)
residue_index
.
append
(
res
.
id
[
1
]
)
chain_ids
.
append
(
chain
.
id
)
chain_ids
.
append
(
chain
.
id
)
b_factors
.
append
(
res_b_factors
)
b_factors
.
append
(
res_b_factors
)
# Chain IDs are usually characters so map these to ints
parents
=
None
parents_chain_index
=
None
if
(
"PARENT"
in
pdb_str
):
parents
=
[]
parents_chain_index
=
[]
chain_id
=
0
for
l
in
pdb_str
.
split
(
"
\n
"
):
if
(
"PARENT"
in
l
):
if
(
not
"N/A"
in
l
):
parent_names
=
l
.
split
()[
1
:]
parents
.
extend
(
parent_names
)
parents_chain_index
.
extend
([
chain_id
for
_
in
parent_names
])
chain_id
+=
1
unique_chain_ids
=
np
.
unique
(
chain_ids
)
unique_chain_ids
=
np
.
unique
(
chain_ids
)
chain_id_mapping
=
{
cid
:
n
for
n
,
cid
in
enumerate
(
unique_chain_ids
)}
chain_id_mapping
=
{
cid
:
n
for
n
,
cid
in
enumerate
(
string
.
ascii_uppercase
)}
chain_index
=
np
.
array
([
chain_id_mapping
[
cid
]
for
cid
in
chain_ids
])
chain_index
=
np
.
array
([
chain_id_mapping
[
cid
]
for
cid
in
chain_ids
])
return
Protein
(
return
Protein
(
...
@@ -149,6 +173,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
...
@@ -149,6 +173,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
residue_index
=
np
.
array
(
residue_index
),
residue_index
=
np
.
array
(
residue_index
),
chain_index
=
chain_index
,
chain_index
=
chain_index
,
b_factors
=
np
.
array
(
b_factors
),
b_factors
=
np
.
array
(
b_factors
),
parents
=
parents
,
parents_chain_index
=
parents_chain_index
,
)
)
...
@@ -213,6 +239,78 @@ def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
...
@@ -213,6 +239,78 @@ def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
)
)
def
get_pdb_headers
(
prot
:
Protein
,
chain_id
:
int
=
0
)
->
Sequence
[
str
]:
pdb_headers
=
[]
remark
=
prot
.
remark
if
(
remark
is
not
None
):
pdb_headers
.
append
(
f
"REMARK
{
remark
}
"
)
parents
=
prot
.
parents
parents_chain_index
=
prot
.
parents_chain_index
if
(
parents_chain_index
is
not
None
):
parents
=
[
p
for
i
,
p
in
zip
(
parents_chain_index
,
parents
)
if
i
==
chain_id
]
if
(
parents
is
None
or
len
(
parents
)
==
0
):
parents
=
[
"N/A"
]
pdb_headers
.
append
(
f
"PARENT
{
' '
.
join
(
parents
)
}
"
)
return
pdb_headers
def
add_pdb_headers
(
prot
:
Protein
,
pdb_str
:
str
)
->
str
:
""" Add pdb headers to an existing PDB string. Useful during multi-chain
recycling
"""
out_pdb_lines
=
[]
lines
=
pdb_str
.
split
(
'
\n
'
)
remark
=
prot
.
remark
if
(
remark
is
not
None
):
out_pdb_lines
.
append
(
f
"REMARK
{
remark
}
"
)
parents_per_chain
=
None
if
(
prot
.
parents
is
not
None
and
len
(
prot
.
parents
)
>
0
):
parents_per_chain
=
[]
if
(
prot
.
parents_chain_index
is
not
None
):
cur_chain
=
prot
.
parents_chain_index
[
0
]
parent_dict
=
{}
for
p
,
i
in
zip
(
prot
.
parents
,
prot
.
parents_chain_index
):
parent_dict
.
setdefault
(
str
(
i
),
[])
parent_dict
[
str
(
i
)].
append
(
p
)
max_idx
=
max
([
int
(
chain_idx
)
for
chain_idx
in
parent_dict
])
for
i
in
range
(
max_idx
+
1
):
chain_parents
=
parent_dict
.
get
(
str
(
i
),
[
"N/A"
])
parents_per_chain
.
append
(
chain_parents
)
else
:
parents_per_chain
.
append
(
prot
.
parents
)
else
:
parents_per_chain
=
[[
"N/A"
]]
make_parent_line
=
lambda
p
:
f
"PARENT
{
' '
.
join
(
p
)
}
"
out_pdb_lines
.
append
(
make_parent_line
(
parents_per_chain
[
0
]))
chain_counter
=
0
for
i
,
l
in
enumerate
(
lines
):
if
(
"PARENT"
not
in
l
and
"REMARK"
not
in
l
):
out_pdb_lines
.
append
(
l
)
if
(
"TER"
in
l
and
not
"END"
in
lines
[
i
+
1
]):
chain_counter
+=
1
if
(
not
chain_counter
>=
len
(
parents_per_chain
)):
chain_parents
=
parents_per_chain
[
chain_counter
]
else
:
chain_parents
=
[
"N/A"
]
out_pdb_lines
.
append
(
make_parent_line
(
chain_parents
))
return
'
\n
'
.
join
(
out_pdb_lines
)
def
to_pdb
(
prot
:
Protein
)
->
str
:
def
to_pdb
(
prot
:
Protein
)
->
str
:
"""Converts a `Protein` instance to a PDB string.
"""Converts a `Protein` instance to a PDB string.
...
@@ -232,8 +330,8 @@ def to_pdb(prot: Protein) -> str:
...
@@ -232,8 +330,8 @@ def to_pdb(prot: Protein) -> str:
aatype
=
prot
.
aatype
aatype
=
prot
.
aatype
atom_positions
=
prot
.
atom_positions
atom_positions
=
prot
.
atom_positions
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
chain_index
=
prot
.
chain_index
.
astype
(
np
.
int32
)
b_factors
=
prot
.
b_factors
b_factors
=
prot
.
b_factors
chain_index
=
prot
.
chain_index
.
astype
(
np
.
int32
)
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
raise
ValueError
(
"Invalid aatypes."
)
raise
ValueError
(
"Invalid aatypes."
)
...
@@ -247,9 +345,17 @@ def to_pdb(prot: Protein) -> str:
...
@@ -247,9 +345,17 @@ def to_pdb(prot: Protein) -> str:
)
)
chain_ids
[
i
]
=
PDB_CHAIN_IDS
[
i
]
chain_ids
[
i
]
=
PDB_CHAIN_IDS
[
i
]
headers
=
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
pdb_lines
.
extend
(
headers
)
pdb_lines
.
append
(
"MODEL 1"
)
pdb_lines
.
append
(
"MODEL 1"
)
n
=
aatype
.
shape
[
0
]
atom_index
=
1
atom_index
=
1
last_chain_index
=
chain_index
[
0
]
last_chain_index
=
chain_index
[
0
]
prev_chain_index
=
0
chain_tags
=
string
.
ascii_uppercase
# Add all atom sites.
# Add all atom sites.
for
i
in
range
(
aatype
.
shape
[
0
]):
for
i
in
range
(
aatype
.
shape
[
0
]):
# Close the previous chain if in a multichain PDB.
# Close the previous chain if in a multichain PDB.
...
@@ -281,10 +387,17 @@ def to_pdb(prot: Protein) -> str:
...
@@ -281,10 +387,17 @@ def to_pdb(prot: Protein) -> str:
0
0
]
# Protein supports only C, N, O, S, this works.
]
# Protein supports only C, N, O, S, this works.
charge
=
""
charge
=
""
chain_tag
=
"A"
if
(
chain_index
is
not
None
):
chain_tag
=
chain_tags
[
chain_index
[
i
]]
# PDB is a columnar format, every space matters here!
# PDB is a columnar format, every space matters here!
atom_line
=
(
atom_line
=
(
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
f
"
{
res_name_3
:
>
3
}
{
chain_ids
[
chain_index
[
i
]]:
>
1
}
"
#TODO: check this refactor, chose main branch version
#f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f
"
{
res_name_3
:
>
3
}
{
chain_tag
:
>
1
}
"
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
f
"
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
"
f
"
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
"
...
@@ -293,16 +406,28 @@ def to_pdb(prot: Protein) -> str:
...
@@ -293,16 +406,28 @@ def to_pdb(prot: Protein) -> str:
pdb_lines
.
append
(
atom_line
)
pdb_lines
.
append
(
atom_line
)
atom_index
+=
1
atom_index
+=
1
# Close the final chain.
should_terminate
=
(
i
==
n
-
1
)
pdb_lines
.
append
(
if
(
chain_index
is
not
None
):
_chain_end
(
if
(
i
!=
n
-
1
and
chain_index
[
i
+
1
]
!=
prev_chain_index
):
atom_index
,
should_terminate
=
True
res_1to3
(
aatype
[
-
1
]),
prev_chain_index
=
chain_index
[
i
+
1
]
chain_ids
[
chain_index
[
-
1
]],
residue_index
[
-
1
]
if
(
should_terminate
):
)
# Close the chain.
)
chain_end
=
"TER"
chain_termination_line
=
(
f
"
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
"
f
"
{
res_1to3
(
aatype
[
i
]):
>
3
}
"
f
"
{
chain_tag
:
>
1
}{
residue_index
[
i
]:
>
4
}
"
)
pdb_lines
.
append
(
chain_termination_line
)
atom_index
+=
1
if
(
i
!=
n
-
1
):
# "prev" is a misnomer here. This happens at the beginning of
# each new chain.
pdb_lines
.
extend
(
get_pdb_headers
(
prot
,
prev_chain_index
))
pdb_lines
.
append
(
"ENDMDL"
)
pdb_lines
.
append
(
"ENDMDL"
)
pdb_lines
.
append
(
"END"
)
pdb_lines
.
append
(
"END"
)
...
@@ -332,6 +457,9 @@ def from_prediction(
...
@@ -332,6 +457,9 @@ def from_prediction(
result
:
ModelOutput
,
result
:
ModelOutput
,
b_factors
:
Optional
[
np
.
ndarray
]
=
None
,
b_factors
:
Optional
[
np
.
ndarray
]
=
None
,
remove_leading_feature_dimension
:
bool
=
True
,
remove_leading_feature_dimension
:
bool
=
True
,
remark
:
Optional
[
str
]
=
None
,
parents
:
Optional
[
Sequence
[
str
]]
=
None
,
parents_chain_index
:
Optional
[
Sequence
[
int
]]
=
None
)
->
Protein
:
)
->
Protein
:
"""Assembles a protein from a prediction.
"""Assembles a protein from a prediction.
...
@@ -341,7 +469,9 @@ def from_prediction(
...
@@ -341,7 +469,9 @@ def from_prediction(
b_factors: (Optional) B-factors to use for the protein.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
of the `features` values
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
Returns:
Returns:
A protein instance.
A protein instance.
"""
"""
...
@@ -349,7 +479,7 @@ def from_prediction(
...
@@ -349,7 +479,7 @@ def from_prediction(
return
arr
[
0
]
if
remove_leading_feature_dimension
else
arr
return
arr
[
0
]
if
remove_leading_feature_dimension
else
arr
if
'asym_id'
in
features
:
if
'asym_id'
in
features
:
chain_index
=
_maybe_remove_leading_dim
(
features
[
"asym_id"
])
chain_index
=
_maybe_remove_leading_dim
(
features
[
"asym_id"
])
-
1
else
:
else
:
chain_index
=
np
.
zeros_like
(
chain_index
=
np
.
zeros_like
(
_maybe_remove_leading_dim
(
features
[
"aatype"
])
_maybe_remove_leading_dim
(
features
[
"aatype"
])
...
@@ -363,6 +493,9 @@ def from_prediction(
...
@@ -363,6 +493,9 @@ def from_prediction(
atom_positions
=
result
[
"final_atom_positions"
],
atom_positions
=
result
[
"final_atom_positions"
],
atom_mask
=
result
[
"final_atom_mask"
],
atom_mask
=
result
[
"final_atom_mask"
],
residue_index
=
_maybe_remove_leading_dim
(
features
[
"residue_index"
])
+
1
,
residue_index
=
_maybe_remove_leading_dim
(
features
[
"residue_index"
])
+
1
,
chain_index
=
chain_index
,
b_factors
=
b_factors
,
b_factors
=
b_factors
,
chain_index
=
chain_index
,
remark
=
remark
,
parents
=
parents
,
parents_chain_index
=
parents_chain_index
,
)
)
openfold/np/relax/__init__.py
View file @
39a6d0e6
import
os
import
glob
import
importlib
as
importlib
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
__all__
=
[
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)
]
_modules
=
[(
m
,
importlib
.
import_module
(
"."
+
m
,
__name__
))
for
m
in
__all__
]
for
_m
in
_modules
:
globals
()[
_m
[
0
]]
=
_m
[
1
]
# Avoid needlessly cluttering the global namespace
del
_files
,
_m
,
_modules
openfold/np/relax/amber_minimize.py
View file @
39a6d0e6
...
@@ -28,10 +28,18 @@ import openfold.utils.loss as loss
...
@@ -28,10 +28,18 @@ import openfold.utils.loss as loss
from
openfold.np.relax
import
cleanup
,
utils
from
openfold.np.relax
import
cleanup
,
utils
import
ml_collections
import
ml_collections
import
numpy
as
np
import
numpy
as
np
from
simtk
import
openmm
try
:
from
simtk
import
unit
# openmm >= 7.6
from
simtk.openmm
import
app
as
openmm_app
import
openmm
from
simtk.openmm.app.internal.pdbstructure
import
PdbStructure
from
openmm
import
unit
from
openmm
import
app
as
openmm_app
from
openmm.app.internal.pdbstructure
import
PdbStructure
except
ImportError
:
# openmm < 7.6 (requires DeepMind patch)
from
simtk
import
openmm
from
simtk
import
unit
from
simtk.openmm
import
app
as
openmm_app
from
simtk.openmm.app.internal.pdbstructure
import
PdbStructure
ENERGY
=
unit
.
kilocalories_per_mole
ENERGY
=
unit
.
kilocalories_per_mole
LENGTH
=
unit
.
angstroms
LENGTH
=
unit
.
angstroms
...
@@ -192,6 +200,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
...
@@ -192,6 +200,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
pdb_string
=
_get_pdb_string
(
as_file
.
getTopology
(),
as_file
.
getPositions
())
pdb_string
=
_get_pdb_string
(
as_file
.
getTopology
(),
as_file
.
getPositions
())
if
checks
:
if
checks
:
_check_cleaned_atoms
(
pdb_string
,
prot_pdb_string
)
_check_cleaned_atoms
(
pdb_string
,
prot_pdb_string
)
headers
=
protein
.
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
pdb_string
=
'
\n
'
.
join
([
'
\n
'
.
join
(
headers
),
pdb_string
])
return
pdb_string
return
pdb_string
...
@@ -511,6 +524,9 @@ def run_pipeline(
...
@@ -511,6 +524,9 @@ def run_pipeline(
_check_residues_are_well_defined
(
prot
)
_check_residues_are_well_defined
(
prot
)
pdb_string
=
clean_protein
(
prot
,
checks
=
checks
)
pdb_string
=
clean_protein
(
prot
,
checks
=
checks
)
# We keep the input around to restore metadata deleted by the relaxer
input_prot
=
prot
exclude_residues
=
exclude_residues
or
[]
exclude_residues
=
exclude_residues
or
[]
exclude_residues
=
set
(
exclude_residues
)
exclude_residues
=
set
(
exclude_residues
)
violations
=
np
.
inf
violations
=
np
.
inf
...
@@ -527,6 +543,11 @@ def run_pipeline(
...
@@ -527,6 +543,11 @@ def run_pipeline(
max_attempts
=
max_attempts
,
max_attempts
=
max_attempts
,
use_gpu
=
use_gpu
,
use_gpu
=
use_gpu
,
)
)
headers
=
protein
.
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
ret
[
"min_pdb"
]
=
'
\n
'
.
join
([
'
\n
'
.
join
(
headers
),
ret
[
"min_pdb"
]])
prot
=
protein
.
from_pdb_string
(
ret
[
"min_pdb"
])
prot
=
protein
.
from_pdb_string
(
ret
[
"min_pdb"
])
if
place_hydrogens_every_iteration
:
if
place_hydrogens_every_iteration
:
pdb_string
=
clean_protein
(
prot
,
checks
=
True
)
pdb_string
=
clean_protein
(
prot
,
checks
=
True
)
...
...
openfold/np/relax/cleanup.py
View file @
39a6d0e6
...
@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure).
...
@@ -20,8 +20,14 @@ cases like removing chains of length one (see clean_structure).
import
io
import
io
import
pdbfixer
import
pdbfixer
from
simtk.openmm
import
app
try
:
from
simtk.openmm.app
import
element
# openmm >= 7.6
from
openmm
import
app
from
openmm.app
import
element
except
ImportError
:
# openmm < 7.6 (requires DeepMind patch)
from
simtk.openmm
import
app
from
simtk.openmm.app
import
element
def
fix_pdb
(
pdbfile
,
alterations_info
):
def
fix_pdb
(
pdbfile
,
alterations_info
):
...
...
openfold/np/relax/relax.py
View file @
39a6d0e6
...
@@ -87,4 +87,7 @@ class AmberRelaxation(object):
...
@@ -87,4 +87,7 @@ class AmberRelaxation(object):
violations
=
out
[
"structural_violations"
][
violations
=
out
[
"structural_violations"
][
"total_per_residue_violations_mask"
"total_per_residue_violations_mask"
]
]
min_pdb
=
protein
.
add_pdb_headers
(
prot
,
min_pdb
)
return
min_pdb
,
debug_data
,
violations
return
min_pdb
,
debug_data
,
violations
openfold/np/relax/utils.py
View file @
39a6d0e6
...
@@ -18,8 +18,14 @@ import io
...
@@ -18,8 +18,14 @@ import io
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
from
Bio
import
PDB
from
Bio
import
PDB
import
numpy
as
np
import
numpy
as
np
from
simtk.openmm
import
app
as
openmm_app
try
:
from
simtk.openmm.app.internal.pdbstructure
import
PdbStructure
# openmm >= 7.6
from
openmm
import
app
as
openmm_app
from
openmm.app.internal.pdbstructure
import
PdbStructure
except
ImportError
:
# openmm < 7.6 (requires DeepMind patch)
from
simtk.openmm
import
app
as
openmm_app
from
simtk.openmm.app.internal.pdbstructure
import
PdbStructure
def
overwrite_pdb_coordinates
(
pdb_str
:
str
,
pos
)
->
str
:
def
overwrite_pdb_coordinates
(
pdb_str
:
str
,
pos
)
->
str
:
...
...
openfold/np/residue_constants.py
View file @
39a6d0e6
...
@@ -1120,10 +1120,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
...
@@ -1120,10 +1120,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
# and an array with (restype, atomtype, coord) for the atom positions
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
# previous group
restype_atom37_to_rigid_group
=
np
.
zeros
([
21
,
37
],
dtype
=
np
.
int
)
restype_atom37_to_rigid_group
=
np
.
zeros
([
21
,
37
],
dtype
=
int
)
restype_atom37_mask
=
np
.
zeros
([
21
,
37
],
dtype
=
np
.
float32
)
restype_atom37_mask
=
np
.
zeros
([
21
,
37
],
dtype
=
np
.
float32
)
restype_atom37_rigid_group_positions
=
np
.
zeros
([
21
,
37
,
3
],
dtype
=
np
.
float32
)
restype_atom37_rigid_group_positions
=
np
.
zeros
([
21
,
37
,
3
],
dtype
=
np
.
float32
)
restype_atom14_to_rigid_group
=
np
.
zeros
([
21
,
14
],
dtype
=
np
.
int
)
restype_atom14_to_rigid_group
=
np
.
zeros
([
21
,
14
],
dtype
=
int
)
restype_atom14_mask
=
np
.
zeros
([
21
,
14
],
dtype
=
np
.
float32
)
restype_atom14_mask
=
np
.
zeros
([
21
,
14
],
dtype
=
np
.
float32
)
restype_atom14_rigid_group_positions
=
np
.
zeros
([
21
,
14
,
3
],
dtype
=
np
.
float32
)
restype_atom14_rigid_group_positions
=
np
.
zeros
([
21
,
14
,
3
],
dtype
=
np
.
float32
)
restype_rigid_group_default_frame
=
np
.
zeros
([
21
,
8
,
4
,
4
],
dtype
=
np
.
float32
)
restype_rigid_group_default_frame
=
np
.
zeros
([
21
,
8
,
4
,
4
],
dtype
=
np
.
float32
)
...
@@ -1279,7 +1279,7 @@ def make_atom14_dists_bounds(
...
@@ -1279,7 +1279,7 @@ def make_atom14_dists_bounds(
restype_atom14_ambiguous_atoms
=
np
.
zeros
((
21
,
14
),
dtype
=
np
.
float32
)
restype_atom14_ambiguous_atoms
=
np
.
zeros
((
21
,
14
),
dtype
=
np
.
float32
)
restype_atom14_ambiguous_atoms_swap_idx
=
np
.
tile
(
restype_atom14_ambiguous_atoms_swap_idx
=
np
.
tile
(
np
.
arange
(
14
,
dtype
=
np
.
int
),
(
21
,
1
)
np
.
arange
(
14
,
dtype
=
int
),
(
21
,
1
)
)
)
...
...
openfold/utils/__init__.py
View file @
39a6d0e6
import
os
import
glob
import
importlib
as
importlib
from
.
import
kernel
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
__all__
=
[
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)
]
+
[
"kernel"
]
_modules
=
[(
m
,
importlib
.
import_module
(
"."
+
m
,
__name__
))
for
m
in
__all__
]
for
_m
in
_modules
:
globals
()[
_m
[
0
]]
=
_m
[
1
]
# Avoid needlessly cluttering the global namespace
del
_files
,
_m
,
_modules
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment