"container/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "a0e1da037c17c4c5a7f990ab09f54d0a8f446994"
Unverified Commit 9e057b7a authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #396 from aqlaboratory/deepspeed-mean-test

Change test_compare_model in deepspeed test to use mean instead of max
parents 561333f2 39f9958b
...@@ -274,7 +274,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -274,7 +274,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
Run full model with and without using DeepSpeed Evoformer attention kernel Run full model with and without using DeepSpeed Evoformer attention kernel
and compare output coordinates. and compare output coordinates.
""" """
eps = 0.5 eps = 0.2
with open("tests/test_data/sample_feats.pickle", "rb") as fp: with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp) batch = pickle.load(fp)
...@@ -316,7 +316,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -316,7 +316,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1].squeeze(0) out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0) out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.max(torch.abs(out_repro - out_repro_ds)) err = torch.mean(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}') self.assertTrue(err < eps, f'Error: {err}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment