Unverified commit d5c88afa, authored by Lei Wang, committed by GitHub
Browse files

[Example] Add correctness assert into dsa example (#937)

parent 242cb457
......@@ -2,6 +2,7 @@
import torch
import tilelang
from tilelang import language as T
from utils import assert_tensors_similar
@tilelang.jit(
......@@ -253,6 +254,12 @@ def test_sparse_mla_fwd(B=1,
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
if SKV <= 4096:
# otherwise may cause out of memory
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices)
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
print("assert_tensors_similar passed")
def fn():
return sparse_mla_fwd_interface(q, kv, indices)
......@@ -270,4 +277,4 @@ def test_sparse_mla_fwd(B=1,
if __name__ == "__main__":
test_sparse_mla_fwd(
B=1, S=4096, SKV=32768, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)
B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)
......@@ -251,25 +251,62 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
return ks, ke
def print_red_warning(message):
    """Print ``message`` to stdout as a red ``WARNING:`` line.

    The text is wrapped in the ANSI colour sequence for a red foreground
    (SGR code 31) and followed by the reset sequence (SGR code 0), so later
    output is not coloured.

    Args:
        message: Value to show after the ``WARNING: `` prefix; it is
            formatted with an f-string, so any object is accepted.
    """
    print(f"\033[31mWARNING: {message}\033[0m")
def calculate_tensor_similarity(x, y, name="tensor"):
    """
    Calculate a single global similarity score between two tensors.

    Unlike torch.testing.assert_close, which applies absolute/relative
    tolerances element by element, this reduces both tensors to one number:

        sim = 2 * <x, y> / (||x||^2 + ||y||^2)

    Properties of the metric:
    - Because ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y>, the quantity
      1 - sim equals ||x - y||^2 / (||x||^2 + ||y||^2); sim is 1 only when
      the tensors are identical.
    - sim is 0 when x and y are orthogonal and -1 when y == -x, so the score
      lies in [-1, 1].
    - Scaling x and y by the same factor leaves sim unchanged, but scaling
      only one of them lowers it: unlike plain cosine similarity, a magnitude
      mismatch is penalized.

    Args:
        x: First tensor to compare
        y: Second tensor to compare
        name: Name of the tensor for logging purposes

    Returns:
        A 0-dim float64 tensor holding the score, or the Python int 1 when
        both inputs are entirely zero (a warning is printed in that case).
    """
    # Reduce in float64 on detached tensors: low-precision inputs such as
    # bfloat16 keep their accuracy in the sums and no autograd graph is built.
    x, y = x.detach().double(), y.detach().double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # Both tensors are all zero; the ratio would be 0/0, so treat them as
        # identical and flag it, since all-zero outputs are usually suspicious.
        print(f"\033[33mWARNING: {name} all zero\033[0m")
        return 1
    return 2 * (x * y).sum() / denominator
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
    """
    Assert that two tensors are similar using a global similarity metric.

    Key differences from torch.testing.assert_close:
    - torch.testing.assert_close: Uses element-wise comparison with rtol/atol,
      checking that |x - y| <= atol + rtol * |y| for each element. Every
      element must satisfy the tolerance, so a single outlier fails the test.
    - assert_tensors_similar: Uses one global score, diff = 1 - sim, where sim
      is the value returned by calculate_tensor_similarity. The check passes
      when 0 <= diff <= eps, so it judges overall agreement rather than
      per-element precision.

    A diff outside [0, eps] counts as a failure; this includes a NaN diff,
    for which both comparisons are false.

    Args:
        x: First tensor to compare
        y: Second tensor to compare
        eps: Maximum allowed difference (1 - similarity), default 1e-8
        name: Name of the tensor for error messages
        raise_assert: Whether to raise AssertionError on failure; when False
            the failure is only printed

    Raises:
        AssertionError: If the check fails and raise_assert is True.
    """
    sim = calculate_tensor_similarity(x, y, name)
    diff = 1. - sim
    if not (0 <= diff <= eps):
        message = f"{name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})"
        print(f"\033[31mERROR: {message}\033[0m")
        if raise_assert:
            # Raise explicitly instead of `assert False`: assert statements
            # are removed under `python -O`, which would silently disable
            # the check. The exception type seen by callers is unchanged.
            raise AssertionError(message)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment