Commit 818db4ae authored by Zimin Li's avatar Zimin Li
Browse files

issue/273: fully support equal_nan option for debug() and debug_all()

parent 7c593b7a
...@@ -224,7 +224,7 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): ...@@ -224,7 +224,7 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
""" """
import numpy as np import numpy as np
print_discrepancy(actual, desired, atol, rtol, verbose) print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
np.testing.assert_allclose( np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
) )
...@@ -270,7 +270,7 @@ def debug_all( ...@@ -270,7 +270,7 @@ def debug_all(
for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)): for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)):
print(f" \033[36mCondition #{index + 1}:\033[0m {actual} == {desired}") print(f" \033[36mCondition #{index + 1}:\033[0m {actual} == {desired}")
indices = print_discrepancy(actual, desired, atol, rtol, verbose) indices = print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
if condition == "or": if condition == "or":
if not passed and len(indices) == 0: if not passed and len(indices) == 0:
passed = True passed = True
...@@ -292,7 +292,7 @@ def debug_all( ...@@ -292,7 +292,7 @@ def debug_all(
assert passed, "\033[31mThe condition has not been satisfied\033[0m" assert passed, "\033[31mThe condition has not been satisfied\033[0m"
def print_discrepancy(actual, expected, atol=0, rtol=1e-3, verbose=True): def print_discrepancy(actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True):
if actual.shape != expected.shape: if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.") raise ValueError("Tensors must have the same shape to compare.")
...@@ -301,8 +301,12 @@ def print_discrepancy(actual, expected, atol=0, rtol=1e-3, verbose=True): ...@@ -301,8 +301,12 @@ def print_discrepancy(actual, expected, atol=0, rtol=1e-3, verbose=True):
is_terminal = sys.stdout.isatty() is_terminal = sys.stdout.isatty()
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate the difference mask based on atol and rtol # Calculate the difference mask based on atol and rtol
diff_mask = torch.abs(actual - expected) > (atol + rtol * torch.abs(expected)) nan_mismatch = actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
diff_mask = nan_mismatch | (torch.abs(actual - expected) > (atol + rtol * torch.abs(expected)))
diff_indices = torch.nonzero(diff_mask, as_tuple=False) diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected delta = actual - expected
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment