Commit f1563999 authored by Christina Floristean's avatar Christina Floristean
Browse files

Minor test fix

parent 7fb12cf5
......@@ -145,7 +145,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
self.assertTrue(err < eps, f'Error item #{i}: {err}')
def compare_evoformer(self, dtype):
def compare_evoformer(self, dtype, eps):
"""
Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
Set dtype to confirm the kernel can be used during both training (BF16) and inference (FP32),
......@@ -155,7 +155,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
n_seq = 18
c_m_shape = (consts.c_m,)
c_z_shape = (consts.c_z,)
eps = 5e-2
activations = {
"msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
......@@ -206,11 +205,11 @@ class TestDeepSpeedKernel(unittest.TestCase):
def test_compare_evoformer_bf16(self):
"""Run evoformer comparison test with BF16 precision."""
self.compare_evoformer(torch.bfloat16)
self.compare_evoformer(dtype=torch.bfloat16, eps=4e-2)
def test_compare_evoformer_fp32(self):
"""Run evoformer comparison test with FP32 precision."""
self.compare_evoformer(torch.float32)
self.compare_evoformer(dtype=torch.float32, eps=2e-2)
def test_compare_template_stack(self):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment