Unverified Commit 8a22f194 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #274 from InfiniTensor/issue/273_fix_python_test_debug

issue/273: Fully Support `equal_nan` Option for `debug()` and `debug_all()`
parents 7c593b7a 818db4ae
......@@ -224,7 +224,7 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
import numpy as np
print_discrepancy(actual, desired, atol, rtol, verbose)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
......@@ -270,7 +270,7 @@ def debug_all(
for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)):
print(f" \033[36mCondition #{index + 1}:\033[0m {actual} == {desired}")
indices = print_discrepancy(actual, desired, atol, rtol, verbose)
indices = print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
if condition == "or":
if not passed and len(indices) == 0:
passed = True
......@@ -292,7 +292,7 @@ def debug_all(
assert passed, "\033[31mThe condition has not been satisfied\033[0m"
def print_discrepancy(actual, expected, atol=0, rtol=1e-3, verbose=True):
def print_discrepancy(actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True):
if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.")
......@@ -301,8 +301,12 @@ def print_discrepancy(actual, expected, atol=0, rtol=1e-3, verbose=True):
is_terminal = sys.stdout.isatty()
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate the difference mask based on atol and rtol
diff_mask = torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
nan_mismatch = actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
diff_mask = nan_mismatch | (torch.abs(actual - expected) > (atol + rtol * torch.abs(expected)))
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment