import torch def get_abs_err(y, x): x = x.to(torch.float32) y = y.to(torch.float32) return (x - y).flatten().abs().max().item() def get_err_ratio(y, x): x = x.to(torch.float32) y = y.to(torch.float32) err = (x - y).flatten().square().mean().sqrt().item() base = (x).flatten().square().mean().sqrt().item() return err / base def calculate_tensor_similarity(x, y, name="tensor"): """ Calculate similarity between two tensors using a normalized dot product metric. Unlike torch.testing.assert_close which uses absolute/relative tolerance based on element-wise differences, this function computes a global similarity score: sim = 2 * / (||x||^2 + ||y||^2) This metric is scale-invariant and measures the cosine-like similarity normalized by the magnitude of both tensors. It returns 1 for identical tensors and values closer to 0 for dissimilar ones. This is particularly useful for comparing tensors with varying magnitudes where relative errors matter more than absolute differences. Args: x: First tensor to compare y: Second tensor to compare name: Name of the tensor for logging purposes Returns: Similarity score in range [0, 1] where 1 means identical """ x, y = x.data.double(), y.data.double() denominator = (x * x + y * y).sum() if denominator == 0: print(f"\033[33mWARNING: {name} all zero\033[0m") return 1 sim = 2 * (x * y).sum() / denominator return sim def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): """ Assert that two tensors are similar using a global similarity metric. Key differences from torch.testing.assert_close: - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers and requires all elements to satisfy the tolerance. - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the normalized dot product. It's more robust to outliers and focuses on overall tensor similarity rather than element-wise precision. This is better suited for comparing large tensors where a few outlier elements shouldn't fail the test. Args: x: First tensor to compare y: Second tensor to compare eps: Maximum allowed difference (1 - similarity), default 1e-8 name: Name of the tensor for error messages raise_assert: Whether to raise assertion error on failure """ sim = calculate_tensor_similarity(x, y, name) diff = 1. - sim if not (0 <= diff <= eps): print( f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m" ) if raise_assert: assert False # noqa: B011