Commit 1e302221 authored by zhuwenwen's avatar zhuwenwen
Browse files

update test_flash_mla

parent d2e73fb4
......@@ -16,7 +16,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-5
assert cos_diff < 1e-4
FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment