Commit 37d0dead authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

Fix assertion in GQA backward example to ensure correct tensor comparison for...

Fix assertion in GQA backward example to ensure correct tensor comparison for gradient validation (#568)
parent f70c1e65
...@@ -342,7 +342,7 @@ def main(BATCH: int = 8, ...@@ -342,7 +342,7 @@ def main(BATCH: int = 8,
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
# assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment