OpenDAS / OpenFold — commit 722a5e01
Authored May 12, 2022 by Gustaf Ahdritz

Improve ease of use of LMA

Parent: 237e26c4
Showing 9 changed files with 208 additions and 114 deletions:
openfold/config.py                        +2   -0
openfold/model/evoformer.py               +70  -12
openfold/model/model.py                   +4   -0
openfold/model/msa.py                     +20  -8
openfold/model/primitives.py              +45  -28
openfold/model/template.py                +13  -6
openfold/model/triangular_attention.py    +6   -4
run_pretrained_openfold.py                +45  -36
tests/test_primitives.py                  +3   -20

openfold/config.py

@@ -80,6 +80,7 @@ def model_config(name, train=False, low_prec=False):
     if train:
         c.globals.blocks_per_ckpt = 1
         c.globals.chunk_size = None
+        c.globals.use_lma = False
     if low_prec:
         c.globals.eps = 1e-4

@@ -269,6 +270,7 @@ config = mlc.ConfigDict(
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
+            "use_lma": False,
             "c_z": c_z,
             "c_m": c_m,
             "c_t": c_t,
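
The flag added here is what the rest of the commit threads through the model: low-memory attention (LMA) is now switched on or off from `config.globals.use_lma` rather than requiring a separate attention module. A minimal sketch of enabling it at inference time, assuming the usual `model_config` entry point and using `"model_1"` purely as an example preset name:

```python
# Sketch only: turn on low-memory attention via the config key added in this commit.
# "model_1" is an example preset name; any valid OpenFold preset works the same way.
from openfold.config import model_config

config = model_config("model_1")
config.globals.use_lma = True       # off by default, as set above
config.globals.chunk_size = 4       # LMA composes with the existing inference chunking
```

The AlphaFold module then reads `self.globals.use_lma` and forwards it to every attention call, as the openfold/model/model.py hunks below show.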

openfold/model/evoformer.py

@@ -183,6 +183,7 @@ class EvoformerBlockCore(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans

@@ -192,21 +193,31 @@ class EvoformerBlockCore(nn.Module):
         pair_trans_mask = pair_mask if _mask_trans else None

         m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
+            m, mask=msa_trans_mask, chunk_size=chunk_size,
         )
         z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size
+            m, mask=msa_mask, chunk_size=chunk_size,
         )
         z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
         z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_start(
+                z, mask=pair_mask, chunk_size=chunk_size, use_lma=use_lma
+            )
         )
         z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_end(
+                z,
+                mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            )
         )
         z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
+            z, mask=pair_trans_mask, chunk_size=chunk_size,
         )

         return m, z

@@ -267,18 +278,31 @@ class EvoformerBlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+            self.msa_att_row(
+                m,
+                z=z,
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            )
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_att_col(
+            m,
+            mask=msa_mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+        )
         m, z = self.core(
             m,
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_lma=use_lma,
             _mask_trans=_mask_trans,
         )

@@ -350,7 +374,9 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _chunk_logits: Optional[int] = 1024,
+        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         def add(m1, m2):
             # The first operation in a checkpoint can't be in-place, but it's

@@ -368,7 +394,8 @@ class ExtraMSABlock(nn.Module):
             z=z.clone() if torch.is_grad_enabled() else z,
             mask=msa_mask,
             chunk_size=chunk_size,
-            use_memory_efficient_kernel=not _chunk_logits,
+            use_lma=use_lma,
+            use_memory_efficient_kernel=not _chunk_logits and not use_lma,
             _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
             _checkpoint_chunks=
                 self.ckpt if torch.is_grad_enabled() else False,

@@ -376,9 +403,23 @@ class ExtraMSABlock(nn.Module):
         ))

         def fn(m, z):
-            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
-            m, z = self.core(
-                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
-            )
+            m = add(m, self.msa_att_col(
+                m,
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            ))
+            m, z = self.core(
+                m,
+                z,
+                msa_mask=msa_mask,
+                pair_mask=pair_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                _mask_trans=_mask_trans,
+            )
             return m, z

@@ -488,6 +529,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """

@@ -500,6 +542,8 @@ class EvoformerStack(nn.Module):
                 [*, N_seq, N_res] MSA mask
             pair_mask:
                 [*, N_res, N_res] pair mask
+            chunk_size: Inference-time subbatch size
+            use_lma: Whether to use low-memory attention during inference
         Returns:
             m:
                 [*, N_seq, N_res, C_m] MSA embedding

@@ -514,6 +558,7 @@ class EvoformerStack(nn.Module):
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
+                use_lma=use_lma,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks

@@ -591,6 +636,7 @@ class ExtraMSAStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,
         _mask_trans: bool = True,

@@ -601,6 +647,8 @@ class ExtraMSAStack(nn.Module):
                 [*, N_extra, N_res, C_m] extra MSA embedding
             z:
                 [*, N_res, N_res, C_z] pair embedding
+            chunk_size: Inference-time subbatch size for Evoformer modules
+            use_lma: Whether to use low-memory attention during inference
             msa_mask:
                 Optional [*, N_extra, N_res] MSA mask
             pair_mask:

@@ -616,7 +664,9 @@ class ExtraMSAStack(nn.Module):
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
-                _chunk_logits=None
+                use_lma=use_lma,
+                _chunk_logits=None,
+                _mask_trans=_mask_trans,
             )
             for b in self.blocks
         ]

@@ -634,7 +684,15 @@ class ExtraMSAStack(nn.Module):
                 m, z = b(m, z)
         else:
             for b in self.blocks:
-                m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+                m, z = b(
+                    m,
+                    z,
+                    msa_mask,
+                    pair_mask,
+                    chunk_size=chunk_size,
+                    use_lma=use_lma,
+                    _mask_trans=_mask_trans,
+                )

                 if(self.clear_cache_between_blocks):
                     torch.cuda.empty_cache()
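
Because every block now accepts the keyword, callers only pass it once at the top of the stack and it propagates to every attention submodule. A call-pattern sketch (not runnable on its own: the `evoformer` instance, embeddings, and masks are assumed to be built as usual):

```python
# Assumed to exist already: an EvoformerStack instance and its inputs.
m, z, s = evoformer(
    m,                      # [*, N_seq, N_res, C_m] MSA embedding
    z,                      # [*, N_res, N_res, C_z] pair embedding
    msa_mask=msa_mask,
    pair_mask=pair_mask,
    chunk_size=4,           # inference-time subbatch size
    use_lma=True,           # forwarded to every attention submodule
)
```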

openfold/model/model.py

@@ -152,6 +152,7 @@ class AlphaFold(nn.Module):
                 template_embeds["pair"],
                 pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
                 _mask_trans=self.config._mask_trans,
             )

@@ -161,6 +162,7 @@ class AlphaFold(nn.Module):
                 z,
                 template_mask=batch["template_mask"].to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
             )
             t = t * (torch.sum(batch["template_mask"]) > 0)

@@ -294,6 +296,7 @@ class AlphaFold(nn.Module):
                 z,
                 msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 _mask_trans=self.config._mask_trans,
             )

@@ -308,6 +311,7 @@ class AlphaFold(nn.Module):
             msa_mask=msa_mask.to(dtype=m.dtype),
             pair_mask=pair_mask.to(dtype=z.dtype),
             chunk_size=self.globals.chunk_size,
+            use_lma=self.globals.use_lma,
             _mask_trans=self.config._mask_trans,
         )

openfold/model/msa.py

@@ -90,12 +90,14 @@ class MSAAttention(nn.Module):
     def _chunk(self,
         m: torch.Tensor,
         biases: List[torch.Tensor],
-        use_memory_efficient_kernel: bool,
         chunk_size: int,
+        use_memory_efficient_kernel: bool,
+        use_lma: bool,
     ) -> torch.Tensor:
         mha = partial(
             self.mha,
-            use_memory_efficient_kernel=use_memory_efficient_kernel
+            use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_lma=use_lma,
         )
         return chunk_layer(
             mha,

@@ -193,6 +195,7 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
     ) -> torch.Tensor:

@@ -224,13 +227,20 @@ class MSAAttention(nn.Module):
             biases.append(z)

         if chunk_size is not None:
-            m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
+            m = self._chunk(
+                m,
+                biases,
+                chunk_size,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+            )
         else:
             m = self.mha(
                 q_x=m,
                 kv_x=m,
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
             )

         return m

@@ -305,7 +315,7 @@ class MSAColumnAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
-        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:

@@ -323,7 +333,7 @@ class MSAColumnAttention(nn.Module):
         if mask is not None:
             mask = mask.transpose(-1, -2)

-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+        m = self._msa_att(m, mask=mask, chunk_size=chunk_size, use_lma=use_lma)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)

@@ -360,13 +370,14 @@ class MSAColumnGlobalAttention(nn.Module):
         m: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_input = {
             "m": m,
             "mask": mask,
         }
         return chunk_layer(
-            self.global_attention,
+            partial(self.global_attention, use_lma=use_lma),
             mha_input,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2]),

@@ -377,6 +388,7 @@ class MSAColumnGlobalAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]

@@ -396,9 +408,9 @@ class MSAColumnGlobalAttention(nn.Module):
         m = self.layer_norm_m(m)

         if chunk_size is not None:
-            m = self._chunk(m, mask, chunk_size)
+            m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
         else:
-            m = self.global_attention(m=m, mask=mask)
+            m = self.global_attention(m=m, mask=mask, use_lma=use_lma)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
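
The same two-step pattern repeats across this file (and in template.py and triangular_attention.py below): per-call options such as `use_lma` are frozen with `functools.partial`, and only the tensor arguments are split into chunks. A self-contained toy version of that pattern, with a stand-in attention function and a simplified chunker rather than OpenFold's own `chunk_layer`:

```python
import torch
from functools import partial

def toy_attn(q_x, kv_x, scale=1.0):
    # stand-in for self.mha / self.global_attention in the diff above
    a = torch.einsum("...qd,...kd->...qk", q_x, kv_x).softmax(dim=-1)
    return scale * (a @ kv_x)

def chunk_layer_toy(fn, inputs, chunk_size):
    # split every tensor input along the leading batch dim, apply fn, re-join
    n = next(iter(inputs.values())).shape[0]
    outs = [
        fn(**{k: v[i:i + chunk_size] for k, v in inputs.items()})
        for i in range(0, n, chunk_size)
    ]
    return torch.cat(outs, dim=0)

fn = partial(toy_attn, scale=0.5)        # options fixed once, like use_lma above
x = torch.randn(8, 16, 4)
out = chunk_layer_toy(fn, {"q_x": x, "kv_x": x}, chunk_size=2)
```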

openfold/model/primitives.py

@@ -31,6 +31,10 @@ from openfold.utils.tensor_utils import (
 )


+DEFAULT_LMA_Q_CHUNK_SIZE = 1024
+DEFAULT_LMA_KV_CHUNK_SIZE = 4096
+
+
 def _prod(nums):
     out = 1
     for n in nums:

@@ -403,8 +407,8 @@ class Attention(nn.Module):
         biases: Optional[List[torch.Tensor]] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
-        q_chunk_size: Optional[int] = None,
-        kv_chunk_size: Optional[int] = None,
+        q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
+        kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
     ) -> torch.Tensor:
         """
         Args:

@@ -460,6 +464,7 @@ class Attention(nn.Module):
                 for b in biases
             ]
             o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+            o = o.transpose(-2, -3)
         else:
             o = _attention(q, k, v, biases)
             o = o.transpose(-2, -3)

@@ -494,7 +499,11 @@ class GlobalAttention(nn.Module):
         self.sigmoid = nn.Sigmoid()

-    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        use_lma: bool = False,
+    ) -> torch.Tensor:
         # [*, N_res, C_in]
         q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
             torch.sum(mask, dim=-1)[..., None] + self.eps

@@ -511,12 +520,13 @@ class GlobalAttention(nn.Module):
         k = self.linear_k(m)
         v = self.linear_v(m)

+        bias = (self.inf * (mask - 1))[..., :, None, :]
+        if(not use_lma):
             # [*, N_res, H, N_seq]
             a = torch.matmul(
                 q,
                 k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
             )
-            bias = (self.inf * (mask - 1))[..., :, None, :]
             a += bias
             a = softmax_no_cast(a)

@@ -525,6 +535,15 @@ class GlobalAttention(nn.Module):
                 a,
                 v,
             )
+        else:
+            o = _lma(
+                q,
+                k,
+                v,
+                [bias],
+                DEFAULT_LMA_Q_CHUNK_SIZE,
+                DEFAULT_LMA_KV_CHUNK_SIZE,
+            )

         # [*, N_res, N_seq, C_hidden]
         g = self.sigmoid(self.linear_g(m))

@@ -552,12 +571,12 @@ def _lma(
     q_chunk_size: int,
     kv_chunk_size: int,
 ):
-    no_q, no_kv = q.shape[-3], k.shape[-3]
+    no_q, no_kv = q.shape[-2], k.shape[-2]

-    # [*, Q, H, C_hidden]
+    # [*, H, Q, C_hidden]
     o = q.new_zeros(q.shape)
     for q_s in range(0, no_q, q_chunk_size):
-        q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
+        q_chunk = q[..., q_s: q_s + q_chunk_size, :]
         large_bias_chunks = [
             b[..., q_s: q_s + q_chunk_size, :] for b in biases
         ]

@@ -566,24 +585,22 @@ def _lma(
         weights = []
         values = []
         for kv_s in range(0, no_kv, kv_chunk_size):
-            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
-            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
+            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
+            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
             small_bias_chunks = [
                 b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
             ]

             a = torch.einsum(
-                "...qhd,...khd->...hqk", q_chunk, k_chunk,
+                "...hqd,...hkd->...hqk", q_chunk, k_chunk,
             )

             for b in small_bias_chunks:
                 a += b

-            a = a.transpose(-2, -3)
-
             max_a = torch.max(a, dim=-1, keepdim=True)[0]
             exp_a = torch.exp(a - max_a)
-            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
+            exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)

             maxes.append(max_a.detach().squeeze(-1))
             weights.append(torch.sum(exp_a, dim=-1))

@@ -595,14 +612,14 @@ def _lma(
         global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
         max_diffs = torch.exp(chunk_max - global_max)
-        chunk_values *= max_diffs.unsqueeze(-1)
-        chunk_weights *= max_diffs
+        chunk_values = chunk_values * max_diffs.unsqueeze(-1)
+        chunk_weights = chunk_weights * max_diffs

         all_values = torch.sum(chunk_values, dim=-4)
         all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

         q_chunk_out = all_values / all_weights

-        o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
+        o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out

     return o
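
For reference, this is the idea behind `_lma` above, stripped down to a single head with no biases: keys and values are visited in chunks while a running max and running normalizer stand in for the full softmax, so the complete [Q, K] score matrix is never materialized. The running-update formulation below is a simplification of `_lma`'s chunk-then-combine scheme; shapes and chunk sizes are illustrative only.

```python
import torch

def lma_sketch(q, k, v, kv_chunk=4):
    # q: [Q, d], k, v: [K, d] -> [Q, d], numerically equal to softmax(q @ k.T) @ v
    out = torch.zeros_like(q)
    running_max = q.new_full((q.shape[0], 1), float("-inf"))
    weight = q.new_zeros((q.shape[0], 1))
    for s in range(0, k.shape[0], kv_chunk):
        a = q @ k[s:s + kv_chunk].transpose(-1, -2)              # [Q, chunk] scores
        new_max = torch.maximum(running_max, a.max(dim=-1, keepdim=True)[0])
        rescale = torch.exp(running_max - new_max)               # fix earlier chunks
        exp_a = torch.exp(a - new_max)
        out = out * rescale + exp_a @ v[s:s + kv_chunk]
        weight = weight * rescale + exp_a.sum(dim=-1, keepdim=True)
        running_max = new_max
    return out / weight

q, k, v = (torch.randn(6, 8) for _ in range(3))
assert torch.allclose(lma_sketch(q, k, v), torch.softmax(q @ k.T, dim=-1) @ v, atol=1e-5)
```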

openfold/model/template.py

@@ -77,6 +77,7 @@ class TemplatePointwiseAttention(nn.Module):
         t: torch.Tensor,
         biases: List[torch.Tensor],
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": z,

@@ -84,7 +85,7 @@ class TemplatePointwiseAttention(nn.Module):
             "biases": biases,
         }
         return chunk_layer(
-            self.mha,
+            partial(self.mha, use_lma=use_lma),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(z.shape[:-2]),

@@ -95,7 +96,8 @@ class TemplatePointwiseAttention(nn.Module):
         t: torch.Tensor,
         z: torch.Tensor,
         template_mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:

@@ -122,9 +124,9 @@ class TemplatePointwiseAttention(nn.Module):
         # [*, N_res, N_res, 1, C_z]
         biases = [bias]
         if chunk_size is not None:
-            z = self._chunk(z, t, biases, chunk_size)
+            z = self._chunk(z, t, biases, chunk_size, use_lma=use_lma)
         else:
-            z = self.mha(q_x=z, kv_x=t, biases=biases)
+            z = self.mha(q_x=z, kv_x=t, biases=biases, use_lma=use_lma)

         # [*, N_res, N_res, C_z]
         z = z.squeeze(-2)

@@ -188,6 +190,7 @@ class TemplatePairStackBlock(nn.Module):
         z: torch.Tensor,
         mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_lma: bool = False,
         _mask_trans: bool = True
     ):
         single_templates = [

@@ -204,14 +207,16 @@ class TemplatePairStackBlock(nn.Module):
                 self.tri_att_start(
                     single,
                     chunk_size=chunk_size,
-                    mask=single_mask
+                    mask=single_mask,
+                    use_lma=use_lma,
                 )
             )
             single = single + self.dropout_col(
                 self.tri_att_end(
                     single,
                     chunk_size=chunk_size,
-                    mask=single_mask
+                    mask=single_mask,
+                    use_lma=use_lma,
                 )
             )
             single = single + self.dropout_row(

@@ -298,6 +303,7 @@ class TemplatePairStack(nn.Module):
         t: torch.tensor,
         mask: torch.tensor,
         chunk_size: int,
+        use_lma: bool = False,
         _mask_trans: bool = True,
     ):
         """

@@ -320,6 +326,7 @@ class TemplatePairStack(nn.Module):
                 b,
                 mask=mask,
                 chunk_size=chunk_size,
+                use_lma=use_lma,
                 _mask_trans=_mask_trans,
             )
             for b in self.blocks

openfold/model/triangular_attention.py

@@ -62,6 +62,7 @@ class TriangleAttention(nn.Module):
         x: torch.Tensor,
         biases: List[torch.Tensor],
         chunk_size: int,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         mha_inputs = {
             "q_x": x,

@@ -69,7 +70,7 @@ class TriangleAttention(nn.Module):
             "biases": biases,
         }
         return chunk_layer(
-            partial(self.mha),
+            partial(self.mha, use_lma=use_lma),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(x.shape[:-2]),

@@ -78,7 +79,8 @@ class TriangleAttention(nn.Module):
     def forward(self,
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_lma: bool = False,
     ) -> torch.Tensor:
         """
         Args:

@@ -113,9 +115,9 @@ class TriangleAttention(nn.Module):
         biases = [mask_bias, triangle_bias]
         if chunk_size is not None:
-            x = self._chunk(x, biases, chunk_size)
+            x = self._chunk(x, biases, chunk_size, use_lma=use_lma)
         else:
-            x = self.mha(q_x=x, kv_x=x, biases=biases)
+            x = self.mha(q_x=x, kv_x=x, biases=biases, use_lma=use_lma)

         if not self.starting:
             x = x.transpose(-2, -3)

run_pretrained_openfold.py

@@ -15,6 +15,7 @@
 import argparse
 from datetime import date
+import gc
 import logging
 import numpy as np
 import os

@@ -76,8 +77,9 @@ def main(args):
     else:
         alignment_dir = args.use_precomputed_alignments

+    for fasta_file in os.listdir(args.fasta_dir):
         # Gather input sequences
-        with open(args.fasta_path, "r") as fp:
+        with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
             data = fp.read()

         lines = [

@@ -86,8 +88,10 @@ def main(args):
         ][1:]
         tags, seqs = lines[::2], lines[1::2]
-        for tag, seq in zip(tags, seqs):
-            fasta_path = os.path.join(args.output_dir, "tmp.fasta")
+        assert len(seqs) == 1, "Input FASTAs may only contain one sequence"
+        tag, seq = tags[0], seqs[0]
+
+        fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
         with open(fasta_path, "w") as fp:
             fp.write(f">{tag}\n{seq}")

@@ -160,6 +164,7 @@ def main(args):
         with open(unrelaxed_output_path, 'w') as f:
             f.write(protein.to_pdb(unrelaxed_protein))

+        if(not args.skip_relaxation):
             amber_relaxer = relax.AmberRelaxation(
                 use_gpu=(args.model_device != "cpu"),
                 **config.relax,

@@ -193,7 +198,8 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "fasta_path", type=str,
+        "fasta_dir", type=str,
+        help="Path to directory containing FASTA files, one sequence per file"
     )
     parser.add_argument(
         "template_mmcif_dir", type=str,

@@ -224,7 +230,7 @@ if __name__ == "__main__":
         openfold/resources/params"""
     )
     parser.add_argument(
-        "--save_outputs", type=bool, default=False,
+        "--save_outputs", action="store_true", default=False,
         help="Whether to save all model outputs, including embeddings, etc."
     )
     parser.add_argument(

@@ -232,11 +238,14 @@ if __name__ == "__main__":
         help="""Number of CPUs with which to run alignment tools"""
     )
     parser.add_argument(
-        '--preset', type=str, default='full_dbs',
+        "--preset", type=str, default='full_dbs',
         choices=('reduced_dbs', 'full_dbs')
     )
     parser.add_argument(
-        '--data_random_seed', type=str, default=None
+        "--data_random_seed", type=str, default=None
+    )
+    parser.add_argument(
+        "--skip_relaxation", action="store_true", default=False,
     )
     add_data_args(parser)
     args = parser.parse_args()
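
With these changes the inference script takes a directory of single-sequence FASTA files instead of one file, and Amber relaxation can be skipped. An invocation now looks roughly like `python run_pretrained_openfold.py fasta_dir/ template_mmcif_dir/ --model_device cuda:0 --skip_relaxation`, where the two directory paths are placeholders and the remaining database, alignment, and output arguments are unchanged from before.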

tests/test_primitives.py

@@ -18,7 +18,6 @@ import unittest
 from openfold.model.primitives import (
     Attention,
-    LowMemoryAttention,
 )
 from tests.config import consts

@@ -31,8 +30,7 @@ class TestLMA(unittest.TestCase):
         no_heads = 4

         q = torch.rand(batch_size, n, c_hidden).cuda()
-        k = torch.rand(batch_size, n, c_hidden).cuda()
-        v = torch.rand(batch_size, n, c_hidden).cuda()
+        kv = torch.rand(batch_size, n, c_hidden).cuda()

         bias = [torch.rand(no_heads, 1, n)]
         bias = [b.cuda() for b in bias]

@@ -40,28 +38,13 @@ class TestLMA(unittest.TestCase):
         gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
         o_fill = torch.rand(c_hidden, c_hidden * no_heads)

-        lma = LowMemoryAttention(
-            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
-        ).cuda()
         a = Attention(
             c_hidden, c_hidden, c_hidden, c_hidden, no_heads
         ).cuda()

         with torch.no_grad():
-            for n, p in lma.named_parameters():
-                attrs = n.split('.')
-                param = a
-                for attr in attrs:
-                    param = getattr(param, attr)
-                param.copy_(p)
-
-            for m in [lma, a]:
-                m.linear_g.weight.copy_(gating_fill)
-                m.linear_o.weight.copy_(o_fill)
-
-        with torch.no_grad():
-            l = lma(q, k, v, 1024, 4096, biases=bias)
-            real = a(q, k, v, biases=bias)
+            l = a(q, kv, biases=bias, use_lma=True)
+            real = a(q, kv, biases=bias)

         self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
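
With the standalone LowMemoryAttention class removed, the same comparison the test makes can be written directly against `Attention`. A minimal sketch (dimensions are illustrative, and unlike the CUDA test above this runs on CPU):

```python
import torch
from openfold.model.primitives import Attention

c_hidden, no_heads, n = 32, 4, 256
attn = Attention(c_hidden, c_hidden, c_hidden, c_hidden, no_heads)

q = torch.rand(1, n, c_hidden)
kv = torch.rand(1, n, c_hidden)
bias = [torch.rand(no_heads, 1, n)]

with torch.no_grad():
    out_lma = attn(q, kv, biases=bias, use_lma=True)   # chunked low-memory path
    out_ref = attn(q, kv, biases=bias)                  # standard attention path

assert torch.max(torch.abs(out_lma - out_ref)) < 1e-4   # paths should agree closely
```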