OpenDAS / OpenFold

Commit 0d148a7d, authored Jul 08, 2022 by Gustaf Ahdritz

    Add FlashAttention support to msa.py

Parent: bcb0b70f

Showing 1 changed file with 44 additions and 16 deletions.

openfold/model/msa.py (+44, -16)
@@ -89,12 +89,14 @@ class MSAAttention(nn.Module):
     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
-        biases: List[torch.Tensor],
+        biases: Optional[List[torch.Tensor]],
         chunk_size: int,
         use_memory_efficient_kernel: bool,
         use_lma: bool,
+        use_flash: bool,
+        flash_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        def fn(m, biases):
+        def fn(m, biases, flash_mask):
             m = self.layer_norm_m(m)
             return self.mha(
                 q_x=m,
@@ -102,14 +104,23 @@ class MSAAttention(nn.Module):
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=flash_mask,
             )
 
+        inputs = {"m": m}
+        if(biases is not None):
+            inputs["biases"] = biases
+        else:
+            fn = partial(fn, biases=None)
+        if(use_flash and flash_mask is not None):
+            inputs["flash_mask"] = flash_mask
+        else:
+            fn = partial(fn, flash_mask=None)
+
         return chunk_layer(
             fn,
-            {"m": m, "biases": biases,},
+            inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2])
         )
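The rewritten _chunk above only hands chunk_layer the tensors that are actually present: optional arguments that are absent (biases, or flash_mask when FlashAttention is off) are instead frozen to None on fn with functools.partial, so the chunking utility never has to slice a None. Below is a minimal, self-contained sketch of that pattern; toy_chunk_layer, fn, and bias are illustrative stand-ins written for this note, not OpenFold's real chunk_layer or its arguments.

    from functools import partial
    from typing import Dict

    import torch


    def toy_chunk_layer(fn, inputs: Dict[str, torch.Tensor], chunk_size: int) -> torch.Tensor:
        # Slice every tensor in `inputs` along its first dimension, apply fn
        # chunk by chunk, and reassemble the output.
        n = next(iter(inputs.values())).shape[0]
        outs = []
        for start in range(0, n, chunk_size):
            chunk = {k: v[start:start + chunk_size] for k, v in inputs.items()}
            outs.append(fn(**chunk))
        return torch.cat(outs, dim=0)


    def fn(m, bias, flash_mask):
        out = m * 2
        if bias is not None:
            out = out + bias
        if flash_mask is not None:
            out = out * flash_mask
        return out


    m = torch.randn(8, 4)
    bias = torch.randn(8, 4)
    flash_mask = None  # optional input that happens to be absent

    # Present tensors are chunked; absent ones are bound to fn as constants.
    inputs = {"m": m}
    if bias is not None:
        inputs["bias"] = bias
    else:
        fn = partial(fn, bias=None)
    if flash_mask is not None:
        inputs["flash_mask"] = flash_mask
    else:
        fn = partial(fn, flash_mask=None)

    out = toy_chunk_layer(fn, inputs, chunk_size=2)
    assert out.shape == m.shape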
@@ -175,6 +186,7 @@ class MSAAttention(nn.Module):
             m, mask_bias, z = self._prep_inputs(
                 m, z, mask, inplace_safe=inplace_safe
             )
+            m = self.layer_norm_m(m)
             q, k, v = self.mha._prep_qkv(m, m)
             return m, q, k, v, mask_bias, z
 
@@ -210,6 +222,7 @@ class MSAAttention(nn.Module):
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
+        use_flash: bool = False,
         inplace_safe: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
@@ -235,15 +248,19 @@ class MSAAttention(nn.Module):
                 chunk_logits=_chunk_logits,
                 checkpoint=_checkpoint_chunks,
                 inplace_safe=inplace_safe,
             )
 
-        m, mask_bias, z = self._prep_inputs(
-            m, z, mask, inplace_safe=inplace_safe
-        )
-
-        biases = [mask_bias]
-        if(z is not None):
-            biases.append(z)
+        if(use_flash):
+            assert z is None
+            biases = None
+        else:
+            m, mask_bias, z = self._prep_inputs(
+                m, z, mask, inplace_safe=inplace_safe
+            )
+
+            biases = [mask_bias]
+            if(z is not None):
+                biases.append(z)
 
         if chunk_size is not None:
             m = self._chunk(
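The hunk above makes the flash branch mutually exclusive with the pair bias: with use_flash set, z must be None and no additive biases are assembled, because the mask is later forwarded to the kernel separately as flash_mask. The toy snippet below illustrates why the two formulations agree on the softmax; the bias construction inf * (mask - 1) mirrors what _prep_inputs does elsewhere in this module and is an assumption here, not part of this diff.

    import torch

    inf = 1e9
    logits = torch.randn(4, 4)
    mask = torch.tensor([1.0, 1.0, 0.0, 1.0])

    # Additive-bias route (non-flash path): masked keys get a huge negative
    # offset on the attention logits before the softmax.
    mask_bias = inf * (mask - 1)          # 0 where kept, -1e9 where masked
    p_bias = torch.softmax(logits + mask_bias[None, :], dim=-1)

    # Mask route (conceptually what a fused kernel does with a key mask):
    # drop the masked keys from the softmax directly.
    p_mask = torch.softmax(logits.masked_fill(mask[None, :] == 0, -inf), dim=-1)

    assert torch.allclose(p_bias, p_mask)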
@@ -252,6 +269,8 @@ class MSAAttention(nn.Module):
                 chunk_size,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=mask,
             )
         else:
             m = self.layer_norm_m(m)
@@ -261,6 +280,8 @@ class MSAAttention(nn.Module):
                 biases=biases,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
                 use_lma=use_lma,
+                use_flash=use_flash,
+                flash_mask=mask,
             )
 
         return m
@@ -336,6 +357,7 @@ class MSAColumnAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
+        use_flash: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -353,7 +375,13 @@ class MSAColumnAttention(nn.Module):
         if mask is not None:
             mask = mask.transpose(-1, -2)
 
-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size, use_lma=use_lma)
+        m = self._msa_att(
+            m,
+            mask=mask,
+            chunk_size=chunk_size,
+            use_lma=use_lma,
+            use_flash=use_flash,
+        )
 
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
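Taken together, the changes thread a single use_flash flag from MSAColumnAttention.forward down through MSAAttention.forward and _chunk into the attention call. A hedged usage sketch follows; the constructor arguments (c_in, c_hidden, no_heads) and the runtime requirements assumed for FlashAttention (CUDA device, half precision) are assumptions, not taken from this diff.

    import torch

    from openfold.model.msa import MSAColumnAttention

    # Hypothetical sizes; constructor signature assumed as (c_in, c_hidden, no_heads).
    n_seq, n_res, c_m = 32, 64, 256
    msa_att = MSAColumnAttention(c_m, 32, 8).cuda().half()

    m = torch.randn(n_seq, n_res, c_m, device="cuda", dtype=torch.half)
    mask = torch.ones(n_seq, n_res, device="cuda", dtype=torch.half)

    # With use_flash=True the mask is handed to the attention kernel as
    # flash_mask instead of being turned into an additive bias.
    out = msa_att(m, mask=mask, use_flash=True)
    assert out.shape == m.shape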