OpenDAS / OpenFold · Commits

Commit f0a320e0
Authored Sep 07, 2023 by Christina Floristean

Integrated deepspeed attention kernel and added initial tests.

Parent: 2134cc09
This commit threads a new use_deepspeed_evo_attention option from config.globals through the Evoformer and ExtraMSA stacks down to the core Attention module, which dispatches to DeepSpeed's DS4Sci_EvoformerAttention kernel, and adds initial tests comparing the kernel against stock attention.

Showing 9 changed files with 254 additions and 37 deletions (+254 −37):

    deepspeed_config.json                     +10   −1
    openfold/config.py                         +1   −0
    openfold/model/evoformer.py               +25   −3
    openfold/model/model.py                    +4   −0
    openfold/model/msa.py                     +10   −3
    openfold/model/primitives.py              +57  −29
    openfold/model/triangular_attention.py     +5   −0
    tests/config.py                            +1   −1
    tests/test_deepspeed_evo_attention.py    +141   −0
deepspeed_config.json

@@ -10,9 +10,18 @@
     "bfloat16": {
         "enabled": true
     },
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 1e-3,
+            "eps": 1e-5
+        }
+    },
     "zero_optimization": {
         "stage": 2,
-        "cpu_offload": true,
+        "offload_optimizer": {
+            "device": "cpu"
+        },
         "contiguous_gradients": true
     },
     "activation_checkpointing": {
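How a config file like this is consumed is not part of the diff; for orientation, a hedged sketch follows. deepspeed.initialize accepts a config path and builds the optimizer declared in the "optimizer" block itself, so no torch optimizer is passed in; the linear model here is purely illustrative.

    import deepspeed
    import torch

    # Illustrative model only; any torch.nn.Module works.
    model = torch.nn.Linear(64, 64)

    # DeepSpeed constructs the Adam optimizer from the JSON config above,
    # so none is supplied here.
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config="deepspeed_config.json",
    )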
openfold/config.py

@@ -367,6 +367,7 @@ config = mlc.ConfigDict(
         "globals": {
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
+            "use_deepspeed_evo_attention": False,
             # Use Staats & Rabe's low-memory attention algorithm. Mutually
             # exclusive with use_flash.
             "use_lma": False,
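The new global defaults to False. A minimal sketch of enabling it, assuming OpenFold's existing model_config helper; per the check added in primitives.py below, at most one of the alternative-attention flags may be True at a time.

    from openfold.config import model_config

    # Sketch, not part of this commit: toggle the new global the same way
    # the existing use_lma / use_flash flags are toggled.
    config = model_config("model_1")
    config.globals.use_deepspeed_evo_attention = True
    config.globals.use_lma = False    # at most one alternative kernel
    config.globals.use_flash = False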
openfold/model/evoformer.py

@@ -181,6 +181,7 @@ class EvoformerBlockCore(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,

@@ -260,6 +261,7 @@ class EvoformerBlockCore(nn.Module):
             mask=pair_mask,
             chunk_size=_attn_chunk_size,
             use_memory_efficient_kernel=False,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )

@@ -279,6 +281,7 @@ class EvoformerBlockCore(nn.Module):
             mask=pair_mask.transpose(-1, -2),
             chunk_size=_attn_chunk_size,
             use_memory_efficient_kernel=False,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )

@@ -365,6 +368,7 @@ class EvoformerBlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,

@@ -392,6 +396,7 @@ class EvoformerBlock(nn.Module):
                 mask=msa_mask,
                 chunk_size=_attn_chunk_size,
                 use_memory_efficient_kernel=False,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
             )
         ),

@@ -403,6 +408,7 @@ class EvoformerBlock(nn.Module):
             m,
             mask=msa_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
         ),

@@ -418,7 +424,8 @@ class EvoformerBlock(nn.Module):
             input_tensors,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,

@@ -494,6 +501,7 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,

@@ -520,7 +528,8 @@ class ExtraMSABlock(nn.Module):
                 mask=msa_mask,
                 chunk_size=_attn_chunk_size,
                 use_lma=use_lma,
-                use_memory_efficient_kernel=not use_lma,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
+                use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
                 _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
             )

@@ -554,6 +563,7 @@ class ExtraMSABlock(nn.Module):
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
             _mask_trans=_mask_trans,

@@ -674,6 +684,7 @@ class EvoformerStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         use_flash: bool,
         msa_mask: Optional[torch.Tensor],

@@ -687,6 +698,7 @@ class EvoformerStack(nn.Module):
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             inplace_safe=inplace_safe,

@@ -726,6 +738,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         _mask_trans: bool = True,

@@ -737,6 +750,7 @@ class EvoformerStack(nn.Module):
             m=input_tensors[0],
             z=input_tensors[1],
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             msa_mask=msa_mask,

@@ -768,6 +782,7 @@ class EvoformerStack(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,

@@ -802,6 +817,7 @@ class EvoformerStack(nn.Module):
             m=m,
             z=z,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             msa_mask=msa_mask,

@@ -882,6 +898,7 @@ class ExtraMSAStack(nn.Module):
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         msa_mask: Optional[torch.Tensor],
         pair_mask: Optional[torch.Tensor],

@@ -893,7 +910,8 @@ class ExtraMSAStack(nn.Module):
                 b,
                 msa_mask=msa_mask,
                 pair_mask=pair_mask,
                 chunk_size=chunk_size,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma,
                 inplace_safe=inplace_safe,
                 _mask_trans=_mask_trans,

@@ -930,6 +948,7 @@ class ExtraMSAStack(nn.Module):
     def _forward_offload(self,
         input_tensors: Sequence[torch.Tensor],
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         msa_mask: Optional[torch.Tensor] = None,
         pair_mask: Optional[torch.Tensor] = None,

@@ -942,6 +961,7 @@ class ExtraMSAStack(nn.Module):
             m=input_tensors[0],
             z=input_tensors[1],
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             msa_mask=msa_mask,
             pair_mask=pair_mask,

@@ -968,6 +988,7 @@ class ExtraMSAStack(nn.Module):
         msa_mask: Optional[torch.Tensor],
         pair_mask: Optional[torch.Tensor],
         chunk_size: int,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
         _mask_trans: bool = True,

@@ -992,6 +1013,7 @@ class ExtraMSAStack(nn.Module):
             m=m,
             z=z,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
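Nearly all of the hunks above are plumbing: the new flag is added to each block's signature and forwarded unchanged at every call site. The one behavioral change is in ExtraMSABlock (@@ -520,7 +528,8 @@), where the memory-efficient CUDA kernel used to be the default whenever LMA was off and now also yields to the DeepSpeed kernel. A hypothetical helper, not code from this commit, that spells out the resulting precedence:

    def _select_msa_kernel(use_deepspeed_evo_attention: bool, use_lma: bool) -> str:
        # Mirrors: use_memory_efficient_kernel = not (use_lma or use_deepspeed_evo_attention)
        if use_deepspeed_evo_attention:
            return "deepspeed_evo_attention"
        if use_lma:
            return "lma"
        return "memory_efficient_kernel"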
openfold/model/model.py

@@ -355,6 +355,7 @@ class AlphaFold(nn.Module):
                 input_tensors,
                 msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
                 use_lma=self.globals.use_lma,
                 pair_mask=pair_mask.to(dtype=m.dtype),
                 _mask_trans=self.config._mask_trans,

@@ -367,6 +368,7 @@ class AlphaFold(nn.Module):
                 a, z,
                 msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
                 use_lma=self.globals.use_lma,
                 pair_mask=pair_mask.to(dtype=m.dtype),
                 inplace_safe=inplace_safe,

@@ -385,6 +387,7 @@ class AlphaFold(nn.Module):
                 msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
                 pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
                 chunk_size=self.globals.chunk_size,
+                use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
                 use_lma=self.globals.use_lma,
                 _mask_trans=self.config._mask_trans,
             )

@@ -397,6 +400,7 @@ class AlphaFold(nn.Module):
                 msa_mask=msa_mask.to(dtype=m.dtype),
                 pair_mask=pair_mask.to(dtype=z.dtype),
                 chunk_size=self.globals.chunk_size,
+                use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
                 use_lma=self.globals.use_lma,
                 use_flash=self.globals.use_flash,
                 inplace_safe=inplace_safe,
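AlphaFold reads the flag from self.globals at each forward pass, so it can also be flipped on an already-constructed model. A hedged sketch; model_config and the AlphaFold constructor are OpenFold's standard entry points, and nothing here beyond the new flag is specific to this commit:

    from openfold.config import model_config
    from openfold.model.model import AlphaFold

    config = model_config("model_1")
    model = AlphaFold(config).eval()

    # Consulted on every forward pass, so it can be toggled after
    # construction; test_dry_run below does exactly this on a
    # pretrained model before running inference.
    model.globals.use_deepspeed_evo_attention = True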
openfold/model/msa.py

@@ -91,7 +91,8 @@ class MSAAttention(nn.Module):
         m: torch.Tensor,
         biases: Optional[List[torch.Tensor]],
         chunk_size: int,
         use_memory_efficient_kernel: bool,
+        use_deepspeed_evo_attention: bool,
         use_lma: bool,
         use_flash: bool,
         flash_mask: Optional[torch.Tensor],

@@ -103,6 +104,7 @@ class MSAAttention(nn.Module):
             kv_x=m,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=flash_mask,

@@ -221,6 +223,7 @@ class MSAAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
         inplace_safe: bool = False,

@@ -267,7 +270,8 @@ class MSAAttention(nn.Module):
             m,
             biases,
             chunk_size,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=mask,

@@ -279,6 +283,7 @@ class MSAAttention(nn.Module):
             kv_x=m,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
             flash_mask=mask,

@@ -356,6 +361,7 @@ class MSAColumnAttention(nn.Module):
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         use_flash: bool = False,
     ) -> torch.Tensor:

@@ -378,7 +384,8 @@ class MSAColumnAttention(nn.Module):
         m = self._msa_att(
             m,
             mask=mask,
             chunk_size=chunk_size,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             use_flash=use_flash,
         )
openfold/model/primitives.py

@@ -12,20 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 import importlib
 import math
 from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np

 deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
-if(deepspeed_is_installed):
+if deepspeed_is_installed:
     import deepspeed
+    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
-if(fa_is_installed):
-    from flash_attn.bert_padding import unpad_input, pad_input
-    from flash_attn.flash_attention import FlashAttention
+if fa_is_installed:
+    from flash_attn.bert_padding import unpad_input
     from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

 import torch

@@ -33,7 +32,6 @@ import torch.nn as nn
 from scipy.stats import truncnorm

 from openfold.utils.checkpointing import get_checkpoint_fn
-from openfold.utils.chunk_utils import _chunk_slice
 from openfold.utils.kernel.attention_core import attention_core
 from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import (

@@ -42,8 +40,8 @@ from openfold.utils.tensor_utils import (
 )

 DEFAULT_LMA_Q_CHUNK_SIZE = 1024
 DEFAULT_LMA_KV_CHUNK_SIZE = 4096


 def _prod(nums):

@@ -196,9 +194,9 @@ class LayerNorm(nn.Module):
         d = x.dtype
         deepspeed_is_initialized = (
             deepspeed_is_installed and
-            deepspeed.utils.is_initialized()
+            deepspeed.comm.comm.is_initialized()
         )
-        if(d is torch.bfloat16 and not deepspeed_is_initialized):
+        if d is torch.bfloat16 and not deepspeed_is_initialized:
             with torch.cuda.amp.autocast(enabled=False):
                 out = nn.functional.layer_norm(
                     x,

@@ -228,9 +226,9 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     d = t.dtype
     deepspeed_is_initialized = (
         deepspeed_is_installed and
-        deepspeed.utils.is_initialized()
+        deepspeed.comm.comm.is_initialized()
     )
-    if(d is torch.bfloat16 and not deepspeed_is_initialized):
+    if d is torch.bfloat16 and not deepspeed_is_initialized:
         with torch.cuda.amp.autocast(enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
     else:

@@ -262,7 +260,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
 def _attention_chunked_trainable(
     query, key, value, biases, chunk_size, chunk_dim, checkpoint,
 ):
-    if(checkpoint and len(biases) > 2):
+    if checkpoint and len(biases) > 2:
         raise ValueError(
             "Checkpointed version permits only permits two bias terms"
         )

@@ -290,7 +288,7 @@ def _attention_chunked_trainable(
         )
         return b[tuple(idx)]

-    if(checkpoint):
+    if checkpoint:
         bias_1_chunk, bias_2_chunk = [
             _slice_bias(b) if b is not None else None
             for b in (biases + [None, None])[:2]

@@ -404,7 +402,7 @@ class Attention(nn.Module):
         o: torch.Tensor,
         q_x: torch.Tensor
     ) -> torch.Tensor:
-        if(self.linear_g is not None):
+        if self.linear_g is not None:
             g = self.sigmoid(self.linear_g(q_x))

             # [*, Q, H, C_hidden]

@@ -425,11 +423,12 @@ class Attention(nn.Module):
         kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
         lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
         use_flash: bool = False,
         flash_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         """
         Args:

@@ -444,6 +443,10 @@ class Attention(nn.Module):
                 This should be the default choice for most. If none of the
                 "use_<...>" flags are True, a stock PyTorch implementation
                 is used instead
+            use_deepspeed_evo_attention:
+                Whether to use DeepSpeed memory-efficient attention kernel.
+                If none of the "use_<...>" flags are True, a stock PyTorch
+                implementation is used instead
             use_lma:
                 Whether to use low-memory attention (Staats & Rabe 2021). If
                 none of the "use_<...>" flags are True, a stock PyTorch

@@ -455,25 +458,25 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
-        if(use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None)):
+        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
             raise ValueError(
                 "If use_lma is specified, lma_q_chunk_size and "
                 "lma_kv_chunk_size must be provided"
             )

-        if(use_flash and biases is not None):
+        if use_flash and biases is not None:
             raise ValueError(
                 "use_flash is incompatible with the bias option. For masking, "
                 "use flash_mask instead"
             )

-        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
-        if(sum(attn_options) > 1):
+        attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
+        if sum(attn_options) > 1:
             raise ValueError(
                 "Choose at most one alternative attention algorithm"
             )

-        if(biases is None):
+        if biases is None:
             biases = []

         # [*, H, Q/K, C_hidden]

@@ -483,22 +486,47 @@ class Attention(nn.Module):
         if is_fp16_enabled():
             use_memory_efficient_kernel = False

-        if(use_memory_efficient_kernel):
-            if(len(biases) > 2):
+        if use_memory_efficient_kernel:
+            if len(biases) > 2:
                 raise ValueError(
                     "If use_memory_efficient_kernel is True, you may only "
                     "provide up to two bias terms"
                 )
             o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
             o = o.transpose(-2, -3)
-        elif(use_lma):
+        elif use_deepspeed_evo_attention:
+            q = q.transpose(-2, -3)
+            k = k.transpose(-2, -3)
+            v = v.transpose(-2, -3)
+
+            add_batch_dim = len(q.shape) < 5
+            if add_batch_dim:
+                q = q.unsqueeze(0)
+                k = k.unsqueeze(0)
+                v = v.unsqueeze(0)
+                biases = [b.unsqueeze(0) for b in biases]
+
+            orig_dtype = q.dtype
+            if orig_dtype not in [torch.bfloat16, torch.float16]:
+                o = DS4Sci_EvoformerAttention(
+                    q.to(dtype=torch.bfloat16),
+                    k.to(dtype=torch.bfloat16),
+                    v.to(dtype=torch.bfloat16),
+                    [b.to(dtype=torch.bfloat16) for b in biases]
+                )
+                o = o.to(dtype=orig_dtype)
+            else:
+                o = DS4Sci_EvoformerAttention(q, k, v, biases)
+
+            if add_batch_dim:
+                o = o.squeeze(0)
+        elif use_lma:
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                 for b in biases
             ]
             o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
             o = o.transpose(-2, -3)
-        elif(use_flash):
+        elif use_flash:
             o = _flash_attn(q, k, v, flash_mask)
         else:
             o = _attention(q, k, v, biases)

@@ -556,7 +584,7 @@ class GlobalAttention(nn.Module):
         v = self.linear_v(m)

         bias = (self.inf * (mask - 1))[..., :, None, :]
-        if(not use_lma):
+        if not use_lma:
             # [*, N_res, H, N_seq]
             a = torch.matmul(
                 q,

@@ -662,7 +690,7 @@ def _lma(
 @torch.jit.ignore
 def _flash_attn(q, k, v, kv_mask):
-    if(not fa_is_installed):
+    if not fa_is_installed:
         raise ValueError(
             "_flash_attn requires that FlashAttention be installed"
         )

@@ -714,8 +742,8 @@ def _flash_attn(q, k, v, kv_mask):
         kv_cu_seqlens,
         q_max_s,
         kv_max_s,
         dropout_p = 0.,
         softmax_scale = 1., # q has been scaled already
     )

     # [*, B, N, H, C]
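Besides the style cleanup (dropping the parenthesized if conditions and the unused functools/flash_attn imports, and switching the DeepSpeed initialization probe to deepspeed.comm), the substantive change is the new dispatch branch. It adapts OpenFold's [*, H, Q/K, C_hidden] layout to what DS4Sci_EvoformerAttention expects: heads after the sequence dims, a rank-5 batched shape, and 16-bit inputs, with fp32 round-tripped through bfloat16. A self-contained sketch of that contract, using a stand-in since the real kernel needs a CUDA build of DeepSpeed:

    import torch

    def stand_in_kernel(q, k, v, biases):
        # Stand-in for DS4Sci_EvoformerAttention: only checks the contract.
        assert q.dim() == 5 and q.dtype in (torch.bfloat16, torch.float16)
        return q

    q = torch.rand(13, 4, 17, 8)      # [N_seq, H, N_res, C_hidden] in fp32
    q = q.transpose(-2, -3)           # -> [N_seq, N_res, H, C_hidden]
    add_batch_dim = q.dim() < 5
    if add_batch_dim:
        q = q.unsqueeze(0)            # -> [1, N_seq, N_res, H, C_hidden]
    orig_dtype = q.dtype
    q16 = q.to(torch.bfloat16)        # fp32 is cast down for the kernel...
    o = stand_in_kernel(q16, q16, q16, []).to(orig_dtype)  # ...and back up
    if add_batch_dim:
        o = o.squeeze(0)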
openfold/model/triangular_attention.py

@@ -63,6 +63,7 @@ class TriangleAttention(nn.Module):
         biases: List[torch.Tensor],
         chunk_size: int,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
     ) -> torch.Tensor:

@@ -77,6 +78,7 @@ class TriangleAttention(nn.Module):
             partial(
                 self.mha,
                 use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
                 use_lma=use_lma
             ),
             mha_inputs,

@@ -90,6 +92,7 @@ class TriangleAttention(nn.Module):
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
         use_memory_efficient_kernel: bool = False,
+        use_deepspeed_evo_attention: bool = False,
         use_lma: bool = False,
         inplace_safe: bool = False,
     ) -> torch.Tensor:

@@ -130,6 +133,7 @@ class TriangleAttention(nn.Module):
             biases,
             chunk_size,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma,
             inplace_safe=inplace_safe,
         )

@@ -139,6 +143,7 @@ class TriangleAttention(nn.Module):
             kv_x=x,
             biases=biases,
             use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention,
             use_lma=use_lma
         )
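In the chunked path the flag is bound with functools.partial before self.mha is handed to the chunking utility, the same pattern already used for the other kernel flags: the wrapper fixes the kernel choice once, and the chunker then only supplies tensor slices. A minimal standalone illustration of that binding, unrelated to OpenFold's actual chunking code:

    from functools import partial

    def mha(x, use_memory_efficient_kernel=False,
            use_deepspeed_evo_attention=False, use_lma=False):
        kernel = ("ds" if use_deepspeed_evo_attention
                  else "lma" if use_lma else "default")
        return kernel, x

    # Bind the kernel flags once; callers then pass only the data chunks.
    bound = partial(mha, use_deepspeed_evo_attention=True)
    print(bound("chunk0"))  # ('ds', 'chunk0')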
tests/config.py

@@ -3,7 +3,7 @@ import ml_collections as mlc
 consts = mlc.ConfigDict(
     {
         "batch_size": 2,
-        "n_res": 11,
+        "n_res": 20,
         "n_seq": 13,
         "n_templ": 3,
         "n_extra": 17,
tests/test_deepspeed_evo_attention.py (new file, mode 100644)

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import unittest
import numpy as np
import pickle

from openfold.model.primitives import (
    Attention,
)
from tests.config import consts
import tests.compare_utils as compare_utils
from openfold.data import data_transforms
from openfold.utils.tensor_utils import tensor_tree_map


class TestDeepSpeedKernel(unittest.TestCase):
    def test_ds_kernel_vs_attention(self):
        batch_size = consts.batch_size
        c_hidden = 32
        n = 2**12
        n_seq = 12
        no_heads = 4

        q = torch.rand(batch_size, n_seq, n, c_hidden).cuda()
        kv = torch.rand(batch_size, n_seq, n, c_hidden).cuda()

        bias = [torch.rand(batch_size, n_seq, 1, 1, n),
                torch.rand(batch_size, 1, no_heads, n, n)]
        bias = [b.cuda() for b in bias]

        a = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        with torch.no_grad():
            l = a(q, kv, biases=bias, use_deepspeed_evo_attention=True)
            real = a(q, kv, biases=bias)

        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)

    def compare_evoformer(self, dtype):
        n_res = consts.n_res
        n_seq = consts.n_seq

        activations = {
            "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
            "pair": torch.rand(n_res, n_res, consts.c_z, device='cuda', dtype=dtype)
        }

        masks = {
            "msa": torch.randint(0, 2, (n_seq, n_res), device='cuda', dtype=dtype),
            "pair": torch.randint(0, 2, (n_res, n_res), device='cuda', dtype=dtype),
        }

        with torch.cuda.amp.autocast(dtype=dtype):
            model = compare_utils.get_global_pretrained_openfold()
            out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
                activations["msa"],
                activations["pair"],
                masks["msa"],
                masks["pair"],
                use_deepspeed_evo_attention=False,
                chunk_size=4,
                _mask_trans=False,
                inplace_safe=False,
            )
            out_repro_msa = out_repro_msa.cpu()
            out_repro_pair = out_repro_pair.cpu()

            out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
                activations["msa"],
                activations["pair"],
                masks["msa"],
                masks["pair"],
                use_deepspeed_evo_attention=True,
                chunk_size=4,
                _mask_trans=False,
                inplace_safe=False,
            )
            out_repro_msa_ds = out_repro_msa_ds.cpu()
            out_repro_pair_ds = out_repro_pair_ds.cpu()

        self.assertTrue(torch.allclose(
            torch.abs(out_repro_msa), torch.abs(out_repro_msa_ds), atol=consts.eps))
        self.assertTrue(torch.allclose(
            torch.abs(out_repro_pair), torch.abs(out_repro_pair_ds), atol=consts.eps))

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare_evoformer_bf16(self):
        self.compare_evoformer(torch.bfloat16)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare_evoformer_fp32(self):
        self.compare_evoformer(torch.float32)

    @compare_utils.skip_unless_alphafold_installed()
    def test_dry_run(self):
        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
            batch = pickle.load(fp)

        # atom37_to_atom14 doesn't like batches
        batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
        batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]

        batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}

        batch["aatype"] = batch["aatype"].long()
        batch["template_aatype"] = batch["template_aatype"].long()
        batch["extra_msa"] = batch["extra_msa"].long()
        batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()
        batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
        batch.update(
            data_transforms.atom37_to_torsion_angles("template_")(batch)
        )

        # Move the recycling dimension to the end
        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
        batch = tensor_tree_map(move_dim, batch)

        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            model.globals.use_deepspeed_evo_attention = True
            out_repro = model(batch)


if __name__ == "__main__":
    unittest.main()
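Given a CUDA GPU and a DeepSpeed build that ships the deepspeed4science ops, these tests can be run on their own with "python -m unittest tests.test_deepspeed_evo_attention -v"; the Evoformer comparison and dry-run cases are additionally gated behind the AlphaFold reference setup in tests.compare_utils, per their skip_unless_alphafold_installed decorators.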