Commit a8d896fd authored by Christina Floristean's avatar Christina Floristean
Browse files

Fix seq min length issue in kernel test

parent e9898a60
......@@ -28,6 +28,6 @@ dependencies:
- wandb==0.12.21
- modelcif==0.7
- git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/microsoft/DeepSpeed.git
- git+https://github.com/cctry/DeepSpeed.git
# TODO: Replace above when version becomes available
# - deepspeed==0.10.4
......@@ -3,7 +3,7 @@ import ml_collections as mlc
consts = mlc.ConfigDict(
{
"batch_size": 2,
"n_res": 20,
"n_res": 11,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
......
......@@ -27,14 +27,8 @@ from openfold.model.primitives import (
)
from tests.config import consts
import tests.compare_utils as compare_utils
from tests.data_utils import (
random_template_feats,
random_extra_msa_feats,
)
from openfold.config import model_config
from openfold.data import data_transforms
from openfold.model.model import AlphaFold
from openfold.utils.tensor_utils import tensor_tree_map
......@@ -70,8 +64,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
since the kernel itself can run with either BF16 or FP16 precision.
"""
n_res = consts.n_res
n_seq = consts.n_seq
n_res = 20
n_seq = 18
eps = 2e-2
activations = {
......@@ -156,20 +150,22 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
out_repro = model(batch)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
model = compare_utils.get_global_pretrained_openfold()
out_repro = model(batch)
# Enable kernel
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model(batch)
# Enable kernel
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
out_repro_ds = tensor_tree_map(lambda t: t.cpu(), out_repro_ds)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
out_repro_ds = tensor_tree_map(lambda t: t.cpu(), out_repro_ds)
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
self.assertTrue(torch.max(torch.abs(out_repro - out_repro_ds)) < eps)
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment