Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
2bf18520
"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "c56d0dea2c6f8a917e1fb4f98dddbd119558cbe7"
Commit
2bf18520
authored
Sep 13, 2023
by
Christina Floristean
Browse files
Clean up DS kernel integration and test, add cutlass to installation procedure
parent
a6703606
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
104 additions
and
32 deletions
+104
-32
environment.yml
environment.yml
+1
-3
openfold/config.py
openfold/config.py
+5
-2
openfold/model/evoformer.py
openfold/model/evoformer.py
+8
-2
openfold/model/primitives.py
openfold/model/primitives.py
+38
-20
openfold/utils/trace_utils.py
openfold/utils/trace_utils.py
+4
-0
scripts/install_third_party_dependencies.sh
scripts/install_third_party_dependencies.sh
+5
-0
tests/test_deepspeed_evo_attention.py
tests/test_deepspeed_evo_attention.py
+43
-5
No files found.
environment.yml
View file @
2bf18520
...
@@ -28,8 +28,6 @@ dependencies:
...
@@ -28,8 +28,6 @@ dependencies:
-
wandb==0.12.21
-
wandb==0.12.21
-
modelcif==0.7
-
modelcif==0.7
-
git+https://github.com/NVIDIA/dllogger.git
-
git+https://github.com/NVIDIA/dllogger.git
-
git+https://github.com/NVIDIA/cutlass.git
-
git+https://github.com/microsoft/DeepSpeed.git
-
git+https://github.com/microsoft/DeepSpeed.git
# TODO: Replace above when version becomes available
# TODO: Replace above when version becomes available
# - deepspeed==0.10.3
# - deepspeed==0.10.4
openfold/config.py
View file @
2bf18520
...
@@ -367,12 +367,15 @@ config = mlc.ConfigDict(
...
@@ -367,12 +367,15 @@ config = mlc.ConfigDict(
"globals"
:
{
"globals"
:
{
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"chunk_size"
:
chunk_size
,
"chunk_size"
:
chunk_size
,
# Use DeepSpeed memory-efficient attention kernel. Mutually
# exclusive with use_lma and use_flash.
"use_deepspeed_evo_attention"
:
False
,
"use_deepspeed_evo_attention"
:
False
,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# Use Staats & Rabe's low-memory attention algorithm. Mutually
# exclusive with use_flash.
# exclusive with
use_deepspeed_evo_attention and
use_flash.
"use_lma"
:
False
,
"use_lma"
:
False
,
# Use FlashAttention in selected modules. Mutually exclusive with
# Use FlashAttention in selected modules. Mutually exclusive with
# use_lma. Doesn't work that well on long sequences (>1000 residues).
# use_deepspeed_evo_attention and use_lma. Doesn't work that well
# on long sequences (>1000 residues).
"use_flash"
:
False
,
"use_flash"
:
False
,
"offload_inference"
:
False
,
"offload_inference"
:
False
,
"c_z"
:
c_z
,
"c_z"
:
c_z
,
...
...
openfold/model/evoformer.py
View file @
2bf18520
...
@@ -801,10 +801,15 @@ class EvoformerStack(nn.Module):
...
@@ -801,10 +801,15 @@ class EvoformerStack(nn.Module):
chunk_size:
chunk_size:
Inference-time subbatch size. Acts as a minimum if
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
use_deepspeed_evo_attention:
Whether to use DeepSpeed memory efficient kernel.
Mutually exclusive with use_lma and use_flash.
use_lma:
Whether to use low-memory attention during inference.
Mutually exclusive with use_flash and use_deepspeed_evo_attention.
use_flash:
use_flash:
Whether to use FlashAttention where possible. Mutually
Whether to use FlashAttention where possible. Mutually
exclusive with use_lma.
exclusive with use_lma
and use_deepspeed_evo_attention
.
Returns:
Returns:
m:
m:
[*, N_seq, N_res, C_m] MSA embedding
[*, N_seq, N_res, C_m] MSA embedding
...
@@ -1000,6 +1005,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -1000,6 +1005,7 @@ class ExtraMSAStack(nn.Module):
z:
z:
[*, N_res, N_res, C_z] pair embedding
[*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules
chunk_size: Inference-time subbatch size for Evoformer modules
use_deepspeed_evo_attention: Whether to use DeepSpeed memory-efficient kernel
use_lma: Whether to use low-memory attention during inference
use_lma: Whether to use low-memory attention during inference
msa_mask:
msa_mask:
Optional [*, N_extra, N_res] MSA mask
Optional [*, N_extra, N_res] MSA mask
...
...
openfold/model/primitives.py
View file @
2bf18520
...
@@ -20,7 +20,9 @@ import numpy as np
...
@@ -20,7 +20,9 @@ import numpy as np
deepspeed_is_installed
=
importlib
.
util
.
find_spec
(
"deepspeed"
)
is
not
None
deepspeed_is_installed
=
importlib
.
util
.
find_spec
(
"deepspeed"
)
is
not
None
if
deepspeed_is_installed
:
if
deepspeed_is_installed
:
import
deepspeed
import
deepspeed
from
deepspeed.ops.deepspeed4science
import
DS4Sci_EvoformerAttention
if
importlib
.
util
.
find_spec
(
"deepspeed.ops.deepspeed4science"
)
is
not
None
:
from
deepspeed.ops.deepspeed4science
import
DS4Sci_EvoformerAttention
fa_is_installed
=
importlib
.
util
.
find_spec
(
"flash_attn"
)
is
not
None
fa_is_installed
=
importlib
.
util
.
find_spec
(
"flash_attn"
)
is
not
None
if
fa_is_installed
:
if
fa_is_installed
:
...
@@ -375,7 +377,8 @@ class Attention(nn.Module):
...
@@ -375,7 +377,8 @@ class Attention(nn.Module):
def
_prep_qkv
(
self
,
def
_prep_qkv
(
self
,
q_x
:
torch
.
Tensor
,
q_x
:
torch
.
Tensor
,
kv_x
:
torch
.
Tensor
kv_x
:
torch
.
Tensor
,
transpose_qkv_dims
:
bool
=
True
)
->
Tuple
[
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
]:
...
@@ -389,10 +392,11 @@ class Attention(nn.Module):
...
@@ -389,10 +392,11 @@ class Attention(nn.Module):
k
=
k
.
view
(
k
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
k
=
k
.
view
(
k
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
v
=
v
.
view
(
v
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
v
=
v
.
view
(
v
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, H, Q/K, C_hidden]
if
transpose_qkv_dims
:
q
=
q
.
transpose
(
-
2
,
-
3
)
# [*, H, Q/K, C_hidden]
k
=
k
.
transpose
(
-
2
,
-
3
)
q
=
q
.
transpose
(
-
2
,
-
3
)
v
=
v
.
transpose
(
-
2
,
-
3
)
k
=
k
.
transpose
(
-
2
,
-
3
)
v
=
v
.
transpose
(
-
2
,
-
3
)
q
/=
math
.
sqrt
(
self
.
c_hidden
)
q
/=
math
.
sqrt
(
self
.
c_hidden
)
...
@@ -479,10 +483,10 @@ class Attention(nn.Module):
...
@@ -479,10 +483,10 @@ class Attention(nn.Module):
if
biases
is
None
:
if
biases
is
None
:
biases
=
[]
biases
=
[]
# [*, H, Q/K, C_hidden]
# DeepSpeed attention kernel expects Q/K/V of shape [*, Q/K, H, C_hidden]
q
,
k
,
v
=
self
.
_prep_qkv
(
q_x
,
kv_x
)
# All other attention modules expect Q/K/V of shape [*, H, Q/K, C_hidden]
q
,
k
,
v
=
self
.
_prep_qkv
(
q_x
,
kv_x
,
transpose_qkv_dims
=
not
use_deepspeed_evo_attention
)
# [*, Q, H, C_hidden]
if
is_fp16_enabled
():
if
is_fp16_enabled
():
use_memory_efficient_kernel
=
False
use_memory_efficient_kernel
=
False
...
@@ -495,17 +499,32 @@ class Attention(nn.Module):
...
@@ -495,17 +499,32 @@ class Attention(nn.Module):
o
=
attention_core
(
q
,
k
,
v
,
*
((
biases
+
[
None
]
*
2
)[:
2
]))
o
=
attention_core
(
q
,
k
,
v
,
*
((
biases
+
[
None
]
*
2
)[:
2
]))
o
=
o
.
transpose
(
-
2
,
-
3
)
o
=
o
.
transpose
(
-
2
,
-
3
)
elif
use_deepspeed_evo_attention
:
elif
use_deepspeed_evo_attention
:
q
=
q
.
transpose
(
-
2
,
-
3
)
if
len
(
biases
)
>
2
:
k
=
k
.
transpose
(
-
2
,
-
3
)
raise
ValueError
(
v
=
v
.
transpose
(
-
2
,
-
3
)
"If use_deepspeed_evo_attention is True, you may only "
"provide up to two bias terms"
)
add_batch_dim
=
len
(
q
.
shape
)
<
5
orig_shape
=
q
.
shape
if
add_batch_dim
:
no_batch_dims
=
len
(
orig_shape
[:
-
3
])
q
=
q
.
unsqueeze
(
0
)
if
no_batch_dims
>
2
:
k
=
k
.
unsqueeze
(
0
)
raise
ValueError
(
v
=
v
.
unsqueeze
(
0
)
f
"Q is of shape
{
list
(
orig_shape
)
}
but must be "
biases
=
[
b
.
unsqueeze
(
0
)
for
b
in
biases
]
"of shape [B, N, Q/K, H, C_hidden] if "
"use_deepspeed_evo_attention is True."
)
# Bypass asserts for bias shapes in DS4Sci_EvoformerAttention()
# by adding batch and N_seq dims if needed.
if
no_batch_dims
<
2
:
addl_dims
=
(
1
,)
*
(
2
-
no_batch_dims
)
q
=
q
.
view
(
*
(
addl_dims
+
q
.
shape
))
k
=
k
.
view
(
*
(
addl_dims
+
k
.
shape
))
v
=
v
.
view
(
*
(
addl_dims
+
v
.
shape
))
biases
=
[
b
.
view
(
*
(
addl_dims
+
b
.
shape
))
for
b
in
biases
]
# DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
# Cast to bf16 so kernel can be used during inference
orig_dtype
=
q
.
dtype
orig_dtype
=
q
.
dtype
if
orig_dtype
not
in
[
torch
.
bfloat16
,
torch
.
float16
]:
if
orig_dtype
not
in
[
torch
.
bfloat16
,
torch
.
float16
]:
o
=
DS4Sci_EvoformerAttention
(
q
.
to
(
dtype
=
torch
.
bfloat16
),
o
=
DS4Sci_EvoformerAttention
(
q
.
to
(
dtype
=
torch
.
bfloat16
),
...
@@ -517,8 +536,7 @@ class Attention(nn.Module):
...
@@ -517,8 +536,7 @@ class Attention(nn.Module):
else
:
else
:
o
=
DS4Sci_EvoformerAttention
(
q
,
k
,
v
,
biases
)
o
=
DS4Sci_EvoformerAttention
(
q
,
k
,
v
,
biases
)
if
add_batch_dim
:
o
=
o
.
view
(
orig_shape
)
o
=
o
.
squeeze
(
0
)
elif
use_lma
:
elif
use_lma
:
biases
=
[
biases
=
[
b
.
expand
(
b
.
shape
[:
-
2
]
+
(
q_x
.
shape
[
-
2
],)
+
(
kv_x
.
shape
[
-
2
],))
b
.
expand
(
b
.
shape
[:
-
2
]
+
(
q_x
.
shape
[
-
2
],)
+
(
kv_x
.
shape
[
-
2
],))
...
...
openfold/utils/trace_utils.py
View file @
2bf18520
...
@@ -181,6 +181,7 @@ def trace_model_(model, sample_input):
...
@@ -181,6 +181,7 @@ def trace_model_(model, sample_input):
(
"mask"
,
msa_mask
),
(
"mask"
,
msa_mask
),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_deepspeed_evo_attention"
,
torch
.
tensor
(
model
.
globals
.
use_deepspeed_evo_attention
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
]
]
verify_arg_order
(
verify_arg_order
(
...
@@ -201,6 +202,7 @@ def trace_model_(model, sample_input):
...
@@ -201,6 +202,7 @@ def trace_model_(model, sample_input):
(
"m"
,
m
),
(
"m"
,
m
),
(
"mask"
,
msa_mask
),
(
"mask"
,
msa_mask
),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_chunk_size
)),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_chunk_size
)),
(
"use_deepspeed_evo_attention"
,
torch
.
tensor
(
model
.
globals
.
use_deepspeed_evo_attention
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"use_flash"
,
torch
.
tensor
(
model
.
globals
.
use_flash
)),
(
"use_flash"
,
torch
.
tensor
(
model
.
globals
.
use_flash
)),
]
]
...
@@ -283,6 +285,7 @@ def trace_model_(model, sample_input):
...
@@ -283,6 +285,7 @@ def trace_model_(model, sample_input):
(
"mask"
,
pair_mask
.
float
()),
(
"mask"
,
pair_mask
.
float
()),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_deepspeed_evo_attention"
,
torch
.
tensor
(
model
.
globals
.
use_deepspeed_evo_attention
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
]
]
...
@@ -305,6 +308,7 @@ def trace_model_(model, sample_input):
...
@@ -305,6 +308,7 @@ def trace_model_(model, sample_input):
(
"mask"
,
pair_mask
.
transpose
(
-
1
,
-
2
).
float
()),
(
"mask"
,
pair_mask
.
transpose
(
-
1
,
-
2
).
float
()),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_deepspeed_evo_attention"
,
torch
.
tensor
(
model
.
globals
.
use_deepspeed_evo_attention
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
]
]
...
...
scripts/install_third_party_dependencies.sh
View file @
2bf18520
...
@@ -25,6 +25,11 @@ git checkout 5b838a8bef
...
@@ -25,6 +25,11 @@ git checkout 5b838a8bef
python3 setup.py
install
python3 setup.py
install
cd
$CUR_DIR
cd
$CUR_DIR
echo
"Attempting to download CUTLASS, required for Deepspeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass.git
conda
env
config vars
set
CUTLASS_PATH
=
$PWD
/cutlass
source
scripts/activate_conda_env.sh
# Install DeepMind's OpenMM patch
# Install DeepMind's OpenMM patch
OPENFOLD_DIR
=
$PWD
OPENFOLD_DIR
=
$PWD
pushd
lib/conda/envs/
$ENV_NAME
/lib/python3.9/site-packages/
\
pushd
lib/conda/envs/
$ENV_NAME
/lib/python3.9/site-packages/
\
...
...
tests/test_deepspeed_evo_attention.py
View file @
2bf18520
...
@@ -12,6 +12,11 @@
...
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
"""
import
torch
import
torch
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
...
@@ -22,17 +27,26 @@ from openfold.model.primitives import (
...
@@ -22,17 +27,26 @@ from openfold.model.primitives import (
)
)
from
tests.config
import
consts
from
tests.config
import
consts
import
tests.compare_utils
as
compare_utils
import
tests.compare_utils
as
compare_utils
from
tests.data_utils
import
(
random_template_feats
,
random_extra_msa_feats
,
)
from
openfold.config
import
model_config
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.model.model
import
AlphaFold
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
class
TestDeepSpeedKernel
(
unittest
.
TestCase
):
class
TestDeepSpeedKernel
(
unittest
.
TestCase
):
def
test_ds_kernel_vs_attention
(
self
):
def
test_ds_kernel_vs_attention
(
self
):
"""Compare attention with and without using DeepSpeed Evoformer kernel."""
batch_size
=
consts
.
batch_size
batch_size
=
consts
.
batch_size
c_hidden
=
32
c_hidden
=
32
n
=
2
**
12
n
=
2
**
12
n_seq
=
12
n_seq
=
12
no_heads
=
4
no_heads
=
4
eps
=
2e-2
q
=
torch
.
rand
(
batch_size
,
n_seq
,
n
,
c_hidden
).
cuda
()
q
=
torch
.
rand
(
batch_size
,
n_seq
,
n
,
c_hidden
).
cuda
()
kv
=
torch
.
rand
(
batch_size
,
n_seq
,
n
,
c_hidden
).
cuda
()
kv
=
torch
.
rand
(
batch_size
,
n_seq
,
n
,
c_hidden
).
cuda
()
...
@@ -48,11 +62,17 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -48,11 +62,17 @@ class TestDeepSpeedKernel(unittest.TestCase):
l
=
a
(
q
,
kv
,
biases
=
bias
,
use_deepspeed_evo_attention
=
True
)
l
=
a
(
q
,
kv
,
biases
=
bias
,
use_deepspeed_evo_attention
=
True
)
real
=
a
(
q
,
kv
,
biases
=
bias
)
real
=
a
(
q
,
kv
,
biases
=
bias
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
l
-
real
))
<
eps
)
def
compare_evoformer
(
self
,
dtype
):
def
compare_evoformer
(
self
,
dtype
):
"""
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
since the kernel itself can run with either BF16 or FP16 precision.
"""
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
eps
=
2e-2
activations
=
{
activations
=
{
"msa"
:
torch
.
rand
(
n_seq
,
n_res
,
consts
.
c_m
,
device
=
'cuda'
,
dtype
=
dtype
),
"msa"
:
torch
.
rand
(
n_seq
,
n_res
,
consts
.
c_m
,
device
=
'cuda'
,
dtype
=
dtype
),
...
@@ -93,16 +113,23 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -93,16 +113,23 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro_msa_ds
=
out_repro_msa_ds
.
cpu
()
out_repro_msa_ds
=
out_repro_msa_ds
.
cpu
()
out_repro_pair_ds
=
out_repro_pair_ds
.
cpu
()
out_repro_pair_ds
=
out_repro_pair_ds
.
cpu
()
self
.
assertTrue
(
torch
.
allclose
(
torch
.
abs
(
out_repro_msa
),
torch
.
abs
(
out_repro_msa_ds
),
atol
=
consts
.
eps
))
self
.
assertTrue
(
torch
.
allclose
(
torch
.
abs
(
out_repro_msa
),
torch
.
abs
(
out_repro_msa_ds
),
atol
=
eps
))
self
.
assertTrue
(
torch
.
allclose
(
torch
.
abs
(
out_repro_pair
),
torch
.
abs
(
out_repro_pair_ds
),
atol
=
consts
.
eps
))
self
.
assertTrue
(
torch
.
allclose
(
torch
.
abs
(
out_repro_pair
),
torch
.
abs
(
out_repro_pair_ds
),
atol
=
eps
))
def
test_compare_evoformer_bf16
(
self
):
def
test_compare_evoformer_bf16
(
self
):
"""Run evoformer comparison test with BF16 precision."""
self
.
compare_evoformer
(
torch
.
bfloat16
)
self
.
compare_evoformer
(
torch
.
bfloat16
)
def
test_compare_evoformer_fp32
(
self
):
def
test_compare_evoformer_fp32
(
self
):
"""Run evoformer comparison test with FP32 precision."""
self
.
compare_evoformer
(
torch
.
float32
)
self
.
compare_evoformer
(
torch
.
float32
)
def
test_dry_run
(
self
):
def
test_compare_model
(
self
):
"""
Run full model with and without using DeepSpeed Evoformer attention kernel
and compare output coordinates
"""
eps
=
2e-2
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
with
open
(
"tests/test_data/sample_feats.pickle"
,
"rb"
)
as
fp
:
batch
=
pickle
.
load
(
fp
)
batch
=
pickle
.
load
(
fp
)
...
@@ -130,9 +157,20 @@ class TestDeepSpeedKernel(unittest.TestCase):
...
@@ -130,9 +157,20 @@ class TestDeepSpeedKernel(unittest.TestCase):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
.
globals
.
use_deepspeed_evo_attention
=
True
out_repro
=
model
(
batch
)
out_repro
=
model
(
batch
)
# Enable kernel
model
.
globals
.
use_deepspeed_evo_attention
=
False
out_repro_ds
=
model
(
batch
)
out_repro
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro
)
out_repro_ds
=
tensor_tree_map
(
lambda
t
:
t
.
cpu
(),
out_repro_ds
)
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
].
squeeze
(
0
)
out_repro_ds
=
out_repro_ds
[
"sm"
][
"positions"
][
-
1
].
squeeze
(
0
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_repro
-
out_repro_ds
))
<
eps
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment