Commit f1563999 authored by Christina Floristean's avatar Christina Floristean
Browse files

Minor test fix

parent 7fb12cf5
@@ -145,7 +145,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
             err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
             self.assertTrue(err < eps, f'Error item #{i}: {err}')

-    def compare_evoformer(self, dtype):
+    def compare_evoformer(self, dtype, eps):
         """
         Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
         Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
@@ -155,7 +155,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
         n_seq = 18
         c_m_shape = (consts.c_m,)
         c_z_shape = (consts.c_z,)
-        eps = 5e-2
         activations = {
             "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
@@ -206,11 +205,11 @@ class TestDeepSpeedKernel(unittest.TestCase):
     def test_compare_evoformer_bf16(self):
         """Run evoformer comparison test with BF16 precision."""
-        self.compare_evoformer(torch.bfloat16)
+        self.compare_evoformer(dtype=torch.bfloat16, eps=4e-2)

     def test_compare_evoformer_fp32(self):
         """Run evoformer comparison test with FP32 precision."""
-        self.compare_evoformer(torch.float32)
+        self.compare_evoformer(dtype=torch.float32, eps=2e-2)

     def test_compare_template_stack(self):
         """
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment