OpenDAS / OpenFold · Commits · 722a5e01

Commit 722a5e01, authored May 12, 2022 by Gustaf Ahdritz
Improve ease of use of LMA
Parent: 237e26c4

Showing 9 changed files with 208 additions and 114 deletions.
openfold/config.py                       +2    -0
openfold/model/evoformer.py              +70   -12
openfold/model/model.py                  +4    -0
openfold/model/msa.py                    +20   -8
openfold/model/primitives.py             +45   -28
openfold/model/template.py               +13   -6
openfold/model/triangular_attention.py   +6    -4
run_pretrained_openfold.py               +45   -36
tests/test_primitives.py                 +3    -20
openfold/config.py

@@ -80,6 +80,7 @@ def model_config(name, train=False, low_prec=False):
     if train:
         c.globals.blocks_per_ckpt = 1
         c.globals.chunk_size = None
+        c.globals.use_lma = False
     if low_prec:
         c.globals.eps = 1e-4

@@ -269,6 +270,7 @@ config = mlc.ConfigDict(
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
+            "use_lma": False,
             "c_z": c_z,
             "c_m": c_m,
             "c_t": c_t,
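The new flag is off by default. A hedged usage sketch (the preset name "model_1" is an assumption; any name accepted by model_config behaves the same way): enabling LMA for inference becomes a one-line config edit.

    from openfold.config import model_config

    config = model_config("model_1")   # "model_1" is an illustrative preset name
    config.globals.use_lma = True      # opt in to low-memory attention at inference time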
openfold/model/evoformer.py

@@ -183,6 +183,7 @@ class EvoformerBlockCore(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans

@@ -192,21 +193,31 @@ class EvoformerBlockCore(nn.Module):
         pair_trans_mask = pair_mask if _mask_trans else None

         m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
+            m, mask=msa_trans_mask, chunk_size=chunk_size,
         )
         z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size
+            m, mask=msa_mask, chunk_size=chunk_size,
         )
         z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_start(
+                z,
+                mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            )
         )
         z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_end(
+                z,
+                mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            )
         )
         z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
+            z, mask=pair_trans_mask, chunk_size=chunk_size,
         )

         return m, z

@@ -267,18 +278,31 @@ class EvoformerBlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+            self.msa_att_row(
+                m,
+                z=z,
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            )
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_att_col(
+            m,
+            mask=msa_mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+        )
         m, z = self.core(
             m,
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_lma=use_lma,
             _mask_trans=_mask_trans,
         )

@@ -350,7 +374,9 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _chunk_logits: Optional[int] = 1024,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         def add(m1, m2):
             # The first operation in a checkpoint can't be in-place, but it's

@@ -368,7 +394,8 @@ class ExtraMSABlock(nn.Module):
                 z=z.clone() if torch.is_grad_enabled() else z,
                 mask=msa_mask,
                 chunk_size=chunk_size,
-                use_memory_efficient_kernel=not _chunk_logits,
+                use_lma=use_lma,
+                use_memory_efficient_kernel=not _chunk_logits and not use_lma,
                 _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
                 _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,

@@ -376,9 +403,23 @@ class ExtraMSABlock(nn.Module):
         ))

         def fn(m, z):
-            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
+            m = add(
+                m,
+                self.msa_att_col(
+                    m,
+                    mask=msa_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                )
+            )
             m, z = self.core(
-                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
+                m,
+                z,
+                msa_mask=msa_mask,
+                pair_mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                _mask_trans=_mask_trans,
             )

             return m, z

@@ -488,6 +529,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """

@@ -500,6 +542,8 @@ class EvoformerStack(nn.Module):
                 [*, N_seq, N_res] MSA mask
             pair_mask:
                 [*, N_res, N_res] pair mask
+            chunk_size: Inference-time subbatch size
+            use_lma: Whether to use low-memory attention during inference
         Returns:
             m:
                 [*, N_seq, N_res, C_m] MSA embedding

@@ -514,6 +558,7 @@ class EvoformerStack(nn.Module):
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
+                use_lma=use_lma,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks

@@ -591,6 +636,7 @@ class ExtraMSAStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
         _mask_trans: bool = True,

@@ -601,6 +647,8 @@ class ExtraMSAStack(nn.Module):
                 [*, N_extra, N_res, C_m] extra MSA embedding
             z:
                 [*, N_res, N_res, C_z] pair embedding
+            chunk_size: Inference-time subbatch size for Evoformer modules
+            use_lma: Whether to use low-memory attention during inference
             msa_mask:
                 Optional [*, N_extra, N_res] MSA mask
             pair_mask:

@@ -616,7 +664,9 @@ class ExtraMSAStack(nn.Module):
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
-                _chunk_logits=None
+                use_lma=use_lma,
+                _chunk_logits=None,
+                _mask_trans=_mask_trans,
             )
             for b in self.blocks
         ]

@@ -634,7 +684,15 @@ class ExtraMSAStack(nn.Module):
                 m, z = b(m, z)
         else:
             for b in self.blocks:
-                m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+                m, z = b(
+                    m,
+                    z,
+                    msa_mask,
+                    pair_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                    _mask_trans=_mask_trans
+                )

                 if(self.clear_cache_between_blocks):
                     torch.cuda.empty_cache()
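One subtlety in the ExtraMSABlock change above: the memory-efficient attention kernel and LMA are now mutually exclusive, and requesting LMA disables the kernel. A small standalone sketch of that selection rule (the function name is hypothetical, for illustration only):

    def wants_memory_efficient_kernel(_chunk_logits, use_lma):
        # mirrors `use_memory_efficient_kernel=not _chunk_logits and not use_lma`
        return (not _chunk_logits) and (not use_lma)

    assert wants_memory_efficient_kernel(None, use_lma=False) is True
    assert wants_memory_efficient_kernel(None, use_lma=True) is False    # LMA takes precedence
    assert wants_memory_efficient_kernel(1024, use_lma=False) is False   # chunked logits take precedence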
openfold/model/model.py

@@ -152,6 +152,7 @@ class AlphaFold(nn.Module):
                 template_embeds["pair"],
                 pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
                 _mask_trans=self.config._mask_trans,
             )

@@ -161,6 +162,7 @@ class AlphaFold(nn.Module):
                 z,
                 template_mask=batch["template_mask"].to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
             )
             t = t * (torch.sum(batch["template_mask"]) > 0)

@@ -294,6 +296,7 @@ class AlphaFold(nn.Module):
                 z,
                 msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 _mask_trans=self.config._mask_trans,
             )

@@ -308,6 +311,7 @@ class AlphaFold(nn.Module):
                 msa_mask=msa_mask.to(dtype=m.dtype),
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
                 _mask_trans=self.config._mask_trans,
             )
openfold/model/msa.py

@@ -90,12 +90,14 @@ class MSAAttention(nn.Module):
     def _chunk(self,
         m: torch.Tensor,
         biases: List[torch.Tensor],
-        use_memory_efficient_kernel: bool,
         chunk_size: int,
+        use_memory_efficient_kernel: bool,
+        use_lma: bool,
     ) -> torch.Tensor:
         mha = partial(
             self.mha,
-            use_memory_efficient_kernel=use_memory_efficient_kernel
+            use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_lma=use_lma,
         )
         return chunk_layer(
             mha,

@@ -193,6 +195,7 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
     ) -> torch.Tensor:

@@ -224,13 +227,20 @@ class MSAAttention(nn.Module):
         biases.append(z)

         if chunk_size is not None:
-            m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
+            m = self._chunk(
+                m,
+                biases,
+                chunk_size,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+            )
         else:
             m = self.mha(
                 q_x=m,
                 kv_x=m,
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
             )

         return m

@@ -305,7 +315,7 @@ class MSAColumnAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
-        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:

@@ -323,7 +333,7 @@ class MSAColumnAttention(nn.Module):
         if mask is not None:
             mask = mask.transpose(-1, -2)

-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+        m = self._msa_att(m, mask=mask, chunk_size=chunk_size, use_lma=use_lma)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)

@@ -360,13 +370,14 @@ class MSAColumnGlobalAttention(nn.Module):
         m: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_input = {
             "m": m,
             "mask": mask,
         }
         return chunk_layer(
-            self.global_attention,
+            partial(self.global_attention, use_lma=use_lma),
             mha_input,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2]),

@@ -377,6 +388,7 @@ class MSAColumnGlobalAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]

@@ -396,9 +408,9 @@ class MSAColumnGlobalAttention(nn.Module):
         m = self.layer_norm_m(m)

         if chunk_size is not None:
-            m = self._chunk(m, mask, chunk_size)
+            m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
         else:
-            m = self.global_attention(m=m, mask=mask)
+            m = self.global_attention(m=m, mask=mask, use_lma=use_lma)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
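The _chunk helpers above all follow the same pattern: per-call flags such as use_lma are bound into the attention callable with functools.partial, so the generic chunking utility only ever deals with tensor inputs. A standalone illustration (mha here is a hypothetical stand-in for the real attention call, not OpenFold code):

    from functools import partial

    def mha(q_x, kv_x, biases=None, use_memory_efficient_kernel=False, use_lma=False):
        ...  # stand-in for Attention.forward

    lma_mha = partial(mha, use_lma=True)
    # lma_mha can be handed to chunk_layer in place of mha; every chunked call then
    # runs with LMA enabled without chunk_layer needing to know about the flag.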
openfold/model/primitives.py

@@ -31,6 +31,10 @@ from openfold.utils.tensor_utils import (
 )


+DEFAULT_LMA_Q_CHUNK_SIZE = 1024
+DEFAULT_LMA_KV_CHUNK_SIZE = 4096
+
+
 def _prod(nums):
     out = 1
     for n in nums:

@@ -403,8 +407,8 @@ class Attention(nn.Module):
         biases: Optional[List[torch.Tensor]] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
-        q_chunk_size: Optional[int] = None,
-        kv_chunk_size: Optional[int] = None,
+        q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
+        kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
     ) -> torch.Tensor:
         """
         Args:

@@ -460,6 +464,7 @@ class Attention(nn.Module):
                 for b in biases
             ]
             o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+            o = o.transpose(-2, -3)
         else:
             o = _attention(q, k, v, biases)
             o = o.transpose(-2, -3)
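With the new defaults, a caller can request LMA without choosing chunk sizes. A hedged usage sketch (constructor arguments are passed positionally as in the updated test at the bottom of this commit; the exact parameter names are not shown in this diff, and biases are assumed to default to an empty list as in the existing forward):

    import torch
    from openfold.model.primitives import Attention

    c_hidden, no_heads = 32, 4
    attn = Attention(c_hidden, c_hidden, c_hidden, c_hidden, no_heads)

    x = torch.rand(1, 2048, c_hidden)
    out = attn(q_x=x, kv_x=x, use_lma=True)   # chunks of 1024 queries / 4096 keys by default
    out_small = attn(q_x=x, kv_x=x, use_lma=True, q_chunk_size=256, kv_chunk_size=512)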
@@ -494,7 +499,11 @@ class GlobalAttention(nn.Module):
         self.sigmoid = nn.Sigmoid()

-    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        use_lma: bool = False,
+    ) -> torch.Tensor:
         # [*, N_res, C_in]
         q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
             torch.sum(mask, dim=-1)[..., None] + self.eps

@@ -511,20 +520,30 @@ class GlobalAttention(nn.Module):
         k = self.linear_k(m)
         v = self.linear_v(m)

-        # [*, N_res, H, N_seq]
-        a = torch.matmul(
-            q,
-            k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
-        )
         bias = (self.inf * (mask - 1))[..., :, None, :]
-        a += bias
-        a = softmax_no_cast(a)
-
-        # [*, N_res, H, C_hidden]
-        o = torch.matmul(
-            a,
-            v,
-        )
+        if(not use_lma):
+            # [*, N_res, H, N_seq]
+            a = torch.matmul(
+                q,
+                k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
+            )
+            a += bias
+            a = softmax_no_cast(a)
+
+            # [*, N_res, H, C_hidden]
+            o = torch.matmul(
+                a,
+                v,
+            )
+        else:
+            o = _lma(
+                q,
+                k,
+                v,
+                [bias],
+                DEFAULT_LMA_Q_CHUNK_SIZE,
+                DEFAULT_LMA_KV_CHUNK_SIZE,
+            )

         # [*, N_res, N_seq, C_hidden]
         g = self.sigmoid(self.linear_g(m))

@@ -552,12 +571,12 @@ def _lma(
     q_chunk_size: int,
     kv_chunk_size: int,
 ):
-    no_q, no_kv = q.shape[-3], k.shape[-3]
+    no_q, no_kv = q.shape[-2], k.shape[-2]

-    # [*, Q, H, C_hidden]
+    # [*, H, Q, C_hidden]
     o = q.new_zeros(q.shape)
     for q_s in range(0, no_q, q_chunk_size):
-        q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
+        q_chunk = q[..., q_s: q_s + q_chunk_size, :]
         large_bias_chunks = [
             b[..., q_s: q_s + q_chunk_size, :] for b in biases
         ]

@@ -566,24 +585,22 @@ def _lma(
         weights = []
         values = []
         for kv_s in range(0, no_kv, kv_chunk_size):
-            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
-            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
+            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
+            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
             small_bias_chunks = [
                 b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
             ]

             a = torch.einsum(
-                "...qhd,...khd->...hqk", q_chunk, k_chunk,
+                "...hqd,...hkd->...hqk", q_chunk, k_chunk,
             )

             for b in small_bias_chunks:
                 a += b

-            a = a.transpose(-2, -3)
-
             max_a = torch.max(a, dim=-1, keepdim=True)[0]
             exp_a = torch.exp(a - max_a)
-            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
+            exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)

             maxes.append(max_a.detach().squeeze(-1))
             weights.append(torch.sum(exp_a, dim=-1))

@@ -595,14 +612,14 @@ def _lma(
         global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
         max_diffs = torch.exp(chunk_max - global_max)
-        chunk_values *= max_diffs.unsqueeze(-1)
-        chunk_weights *= max_diffs
+        chunk_values = chunk_values * max_diffs.unsqueeze(-1)
+        chunk_weights = chunk_weights * max_diffs

         all_values = torch.sum(chunk_values, dim=-4)
         all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

         q_chunk_out = all_values / all_weights
-        o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
+        o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out

     return o
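For reference, the scheme that _lma implements (and that Attention and GlobalAttention route through when use_lma is set) can be written down compactly. The following is a minimal standalone sketch under simplifying assumptions — a single dense bias, no extra batch-dimension handling, no in-place output reuse — so it is illustrative only, not the OpenFold implementation itself:

    import torch

    def chunked_attention(q, k, v, bias, q_chunk_size=1024, kv_chunk_size=4096):
        # q: [H, Q, D], k/v: [H, K, D], bias: [H, Q, K]; the full [Q, K] attention
        # matrix is never materialized, only [q_chunk, kv_chunk] tiles.
        no_q, no_kv = q.shape[-2], k.shape[-2]
        o = q.new_zeros(q.shape[:-1] + (v.shape[-1],))
        for q_s in range(0, no_q, q_chunk_size):
            q_chunk = q[..., q_s:q_s + q_chunk_size, :]
            bias_q = bias[..., q_s:q_s + q_chunk_size, :]

            maxes, weights, values = [], [], []
            for kv_s in range(0, no_kv, kv_chunk_size):
                k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :]
                v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :]

                a = torch.einsum("...qd,...kd->...qk", q_chunk, k_chunk)
                a = a + bias_q[..., kv_s:kv_s + kv_chunk_size]

                max_a = torch.max(a, dim=-1, keepdim=True)[0]   # per-row running max
                exp_a = torch.exp(a - max_a)                    # stabilized exponentials
                exp_v = torch.einsum("...kf,...qk->...qf", v_chunk, exp_a)

                maxes.append(max_a.squeeze(-1))
                weights.append(torch.sum(exp_a, dim=-1))
                values.append(exp_v)

            # Combine per-chunk statistics with a global log-sum-exp correction.
            chunk_max = torch.stack(maxes, dim=-3)
            chunk_weights = torch.stack(weights, dim=-3)
            chunk_values = torch.stack(values, dim=-4)

            global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
            max_diffs = torch.exp(chunk_max - global_max)
            chunk_values = chunk_values * max_diffs.unsqueeze(-1)
            chunk_weights = chunk_weights * max_diffs

            all_values = torch.sum(chunk_values, dim=-4)
            all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
            o[..., q_s:q_s + q_chunk_size, :] = all_values / all_weights
        return o

    # Sanity check against ordinary softmax attention on sizes where both fit in memory:
    h, q_len, k_len, d = 2, 64, 96, 8
    q, k, v = torch.rand(h, q_len, d), torch.rand(h, k_len, d), torch.rand(h, k_len, d)
    bias = torch.zeros(h, q_len, k_len)
    ref = torch.softmax(torch.einsum("hqd,hkd->hqk", q, k) + bias, dim=-1) @ v
    assert torch.allclose(chunked_attention(q, k, v, bias, 16, 32), ref, atol=1e-5)

The trade-off is the usual one for chunked attention: peak memory scales with q_chunk_size × kv_chunk_size rather than with the full sequence length squared, at the cost of revisiting keys and values once per query chunk.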
openfold/model/template.py

@@ -77,6 +77,7 @@ class TemplatePointwiseAttention(nn.Module):
         t: torch.Tensor,
         biases: List[torch.Tensor],
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": z,

@@ -84,7 +85,7 @@ class TemplatePointwiseAttention(nn.Module):
             "biases": biases,
         }
         return chunk_layer(
-            self.mha,
+            partial(self.mha, use_lma=use_lma),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(z.shape[:-2]),

@@ -95,7 +96,8 @@ class TemplatePointwiseAttention(nn.Module):
         t: torch.Tensor,
         z: torch.Tensor,
         template_mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:

@@ -122,9 +124,9 @@ class TemplatePointwiseAttention(nn.Module):
         # [*, N_res, N_res, 1, C_z]
         biases = [bias]
         if chunk_size is not None:
-            z = self._chunk(z, t, biases, chunk_size)
+            z = self._chunk(z, t, biases, chunk_size, use_lma=use_lma)
         else:
-            z = self.mha(q_x=z, kv_x=t, biases=biases)
+            z = self.mha(q_x=z, kv_x=t, biases=biases, use_lma=use_lma)

         # [*, N_res, N_res, C_z]
         z = z.squeeze(-2)

@@ -188,6 +190,7 @@ class TemplatePairStackBlock(nn.Module):
         z: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ):
         single_templates = [

@@ -204,14 +207,16 @@ class TemplatePairStackBlock(nn.Module):
                 self.tri_att_start(
                     single,
                     chunk_size=chunk_size,
-                    mask=single_mask
+                    mask=single_mask,
+                    use_lma=use_lma,
                 )
             )
             single = single + self.dropout_col(
                 self.tri_att_end(
                     single,
                     chunk_size=chunk_size,
-                    mask=single_mask
+                    mask=single_mask,
+                    use_lma=use_lma,
                 )
             )
             single = single + self.dropout_row(

@@ -298,6 +303,7 @@ class TemplatePairStack(nn.Module):
         t: torch.tensor,
         mask: torch.tensor,
         chunk_size: int,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ):
         """

@@ -320,6 +326,7 @@ class TemplatePairStack(nn.Module):
                 b,
                 mask=mask,
                 chunk_size=chunk_size,
+                use_lma=use_lma,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks
openfold/model/triangular_attention.py

@@ -62,6 +62,7 @@ class TriangleAttention(nn.Module):
         x: torch.Tensor,
         biases: List[torch.Tensor],
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": x,

@@ -69,7 +70,7 @@ class TriangleAttention(nn.Module):
             "biases": biases,
         }
         return chunk_layer(
-            partial(self.mha),
+            partial(self.mha, use_lma=use_lma),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(x.shape[:-2]),

@@ -78,7 +79,8 @@ class TriangleAttention(nn.Module):
     def forward(self,
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:

@@ -113,9 +115,9 @@ class TriangleAttention(nn.Module):
         biases = [mask_bias, triangle_bias]

         if chunk_size is not None:
-            x = self._chunk(x, biases, chunk_size)
+            x = self._chunk(x, biases, chunk_size, use_lma=use_lma)
         else:
-            x = self.mha(q_x=x, kv_x=x, biases=biases)
+            x = self.mha(q_x=x, kv_x=x, biases=biases, use_lma=use_lma)

         if not self.starting:
             x = x.transpose(-2, -3)
run_pretrained_openfold.py

@@ -15,6 +15,7 @@
 import argparse
 from datetime import date
+import gc
 import logging
 import numpy as np
 import os

@@ -76,18 +77,21 @@ def main(args):
     else:
         alignment_dir = args.use_precomputed_alignments

-    # Gather input sequences
-    with open(args.fasta_path, "r") as fp:
-        data = fp.read()
+    for fasta_file in os.listdir(args.fasta_dir):
+        # Gather input sequences
+        with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
+            data = fp.read()

-    lines = [
-        l.replace('\n', '')
-        for prot in data.split('>') for l in prot.strip().split('\n', 1)
-    ][1:]
-    tags, seqs = lines[::2], lines[1::2]
+        lines = [
+            l.replace('\n', '')
+            for prot in data.split('>') for l in prot.strip().split('\n', 1)
+        ][1:]
+        tags, seqs = lines[::2], lines[1::2]

-    for tag, seq in zip(tags, seqs):
-        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
+        assert len(seqs) == 1, "Input FASTAs may only contain one sequence"
+        tag, seq = tags[0], seqs[0]
+
+        fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
         with open(fasta_path, "w") as fp:
             fp.write(f">{tag}\n{seq}")

@@ -123,7 +127,7 @@ def main(args):
         processed_feature_dict = feature_processor.process_features(
             feature_dict, mode='predict',
         )

         logging.info("Executing model...")
         batch = processed_feature_dict
         with torch.no_grad():

@@ -160,27 +164,28 @@ def main(args):
         with open(unrelaxed_output_path, 'w') as f:
             f.write(protein.to_pdb(unrelaxed_protein))

-        amber_relaxer = relax.AmberRelaxation(
-            use_gpu=(args.model_device != "cpu"),
-            **config.relax,
-        )
-
-        # Relax the prediction.
-        t = time.perf_counter()
-        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
-        if("cuda" in args.model_device):
-            device_no = args.model_device.split(":")[-1]
-            os.environ["CUDA_VISIBLE_DEVICES"] = device_no
-        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-        os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
-        logging.info(f"Relaxation time: {time.perf_counter() - t}")
-
-        # Save the relaxed PDB.
-        relaxed_output_path = os.path.join(
-            args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
-        )
-        with open(relaxed_output_path, 'w') as f:
-            f.write(relaxed_pdb_str)
+        if(not args.skip_relaxation):
+            amber_relaxer = relax.AmberRelaxation(
+                use_gpu=(args.model_device != "cpu"),
+                **config.relax,
+            )
+
+            # Relax the prediction.
+            t = time.perf_counter()
+            visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
+            if("cuda" in args.model_device):
+                device_no = args.model_device.split(":")[-1]
+                os.environ["CUDA_VISIBLE_DEVICES"] = device_no
+            relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+            os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
+            logging.info(f"Relaxation time: {time.perf_counter() - t}")
+
+            # Save the relaxed PDB.
+            relaxed_output_path = os.path.join(
+                args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
+            )
+            with open(relaxed_output_path, 'w') as f:
+                f.write(relaxed_pdb_str)

         if(args.save_outputs):
             output_dict_path = os.path.join(

@@ -193,7 +198,8 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "fasta_path", type=str,
+        "fasta_dir", type=str,
+        help="Path to directory containing FASTA files, one sequence per file"
     )
     parser.add_argument(
         "template_mmcif_dir", type=str,

@@ -224,7 +230,7 @@ if __name__ == "__main__":
     openfold/resources/params"""
     )
     parser.add_argument(
-        "--save_outputs", type=bool, default=False,
+        "--save_outputs", action="store_true", default=False,
         help="Whether to save all model outputs, including embeddings, etc."
     )
     parser.add_argument(

@@ -232,11 +238,14 @@ if __name__ == "__main__":
         help="""Number of CPUs with which to run alignment tools"""
     )
     parser.add_argument(
-        '--preset', type=str, default='full_dbs',
+        "--preset", type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
     )
     parser.add_argument(
-        '--data_random_seed', type=str, default=None
+        "--data_random_seed", type=str, default=None
     )
+    parser.add_argument(
+        "--skip_relaxation", action="store_true", default=False,
+    )
     add_data_args(parser)
     args = parser.parse_args()
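Taken together, the script now iterates over a directory of single-sequence FASTA files instead of one fasta_path, relaxation can be skipped with --skip_relaxation, and --save_outputs is a plain flag. A hedged invocation sketch (--output_dir and --model_device are inferred from the args.* usages elsewhere in the script, not from these hunks):

    python run_pretrained_openfold.py fasta_dir/ template_mmcif_dir/ \
        --output_dir predictions/ \
        --model_device cuda:0 \
        --preset full_dbs \
        --skip_relaxation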
tests/test_primitives.py

@@ -18,7 +18,6 @@ import unittest
 from openfold.model.primitives import (
     Attention,
-    LowMemoryAttention,
 )
 from tests.config import consts

@@ -31,8 +30,7 @@ class TestLMA(unittest.TestCase):
         no_heads = 4

         q = torch.rand(batch_size, n, c_hidden).cuda()
-        k = torch.rand(batch_size, n, c_hidden).cuda()
-        v = torch.rand(batch_size, n, c_hidden).cuda()
+        kv = torch.rand(batch_size, n, c_hidden).cuda()

         bias = [torch.rand(no_heads, 1, n)]
         bias = [b.cuda() for b in bias]

@@ -40,28 +38,13 @@ class TestLMA(unittest.TestCase):
         gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
         o_fill = torch.rand(c_hidden, c_hidden * no_heads)

-        lma = LowMemoryAttention(
-            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
-        ).cuda()
         a = Attention(
             c_hidden, c_hidden, c_hidden, c_hidden, no_heads
         ).cuda()

-        with torch.no_grad():
-            for n, p in lma.named_parameters():
-                attrs = n.split('.')
-                param = a
-                for attr in attrs:
-                    param = getattr(param, attr)
-                param.copy_(p)
-
-            for m in [lma, a]:
-                m.linear_g.weight.copy_(gating_fill)
-                m.linear_o.weight.copy_(o_fill)
-
         with torch.no_grad():
-            l = lma(q, k, v, 1024, 4096, biases=bias)
-            real = a(q, k, v, biases=bias)
+            l = a(q, kv, biases=bias, use_lma=True)
+            real = a(q, kv, biases=bias)

         self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)