Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
818db4ae
Commit
818db4ae
authored
Jun 20, 2025
by
Zimin Li
Browse files
issue/273: fully support equal_nan option for debug() and debug_all()
parent
7c593b7a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
test/infiniop/libinfiniop/utils.py
test/infiniop/libinfiniop/utils.py
+8
-4
No files found.
test/infiniop/libinfiniop/utils.py
View file @
818db4ae
...
@@ -224,7 +224,7 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
...
@@ -224,7 +224,7 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
"""
import
numpy
as
np
import
numpy
as
np
print_discrepancy
(
actual
,
desired
,
atol
,
rtol
,
verbose
)
print_discrepancy
(
actual
,
desired
,
atol
,
rtol
,
equal_nan
,
verbose
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
actual
.
cpu
(),
desired
.
cpu
(),
rtol
,
atol
,
equal_nan
,
verbose
=
True
actual
.
cpu
(),
desired
.
cpu
(),
rtol
,
atol
,
equal_nan
,
verbose
=
True
)
)
...
@@ -270,7 +270,7 @@ def debug_all(
...
@@ -270,7 +270,7 @@ def debug_all(
for
index
,
(
actual
,
desired
)
in
enumerate
(
zip
(
actual_vals
,
desired_vals
)):
for
index
,
(
actual
,
desired
)
in
enumerate
(
zip
(
actual_vals
,
desired_vals
)):
print
(
f
"
\033
[36mCondition #
{
index
+
1
}
:
\033
[0m
{
actual
}
==
{
desired
}
"
)
print
(
f
"
\033
[36mCondition #
{
index
+
1
}
:
\033
[0m
{
actual
}
==
{
desired
}
"
)
indices
=
print_discrepancy
(
actual
,
desired
,
atol
,
rtol
,
verbose
)
indices
=
print_discrepancy
(
actual
,
desired
,
atol
,
rtol
,
equal_nan
,
verbose
)
if
condition
==
"or"
:
if
condition
==
"or"
:
if
not
passed
and
len
(
indices
)
==
0
:
if
not
passed
and
len
(
indices
)
==
0
:
passed
=
True
passed
=
True
...
@@ -292,7 +292,7 @@ def debug_all(
...
@@ -292,7 +292,7 @@ def debug_all(
assert
passed
,
"
\033
[31mThe condition has not been satisfied
\033
[0m"
assert
passed
,
"
\033
[31mThe condition has not been satisfied
\033
[0m"
def
print_discrepancy
(
actual
,
expected
,
atol
=
0
,
rtol
=
1e-3
,
verbose
=
True
):
def
print_discrepancy
(
actual
,
expected
,
atol
=
0
,
rtol
=
1e-3
,
equal_nan
=
True
,
verbose
=
True
):
if
actual
.
shape
!=
expected
.
shape
:
if
actual
.
shape
!=
expected
.
shape
:
raise
ValueError
(
"Tensors must have the same shape to compare."
)
raise
ValueError
(
"Tensors must have the same shape to compare."
)
...
@@ -301,8 +301,12 @@ def print_discrepancy(actual, expected, atol=0, rtol=1e-3, verbose=True):
...
@@ -301,8 +301,12 @@ def print_discrepancy(actual, expected, atol=0, rtol=1e-3, verbose=True):
is_terminal
=
sys
.
stdout
.
isatty
()
is_terminal
=
sys
.
stdout
.
isatty
()
actual_isnan
=
torch
.
isnan
(
actual
)
expected_isnan
=
torch
.
isnan
(
expected
)
# Calculate the difference mask based on atol and rtol
# Calculate the difference mask based on atol and rtol
diff_mask
=
torch
.
abs
(
actual
-
expected
)
>
(
atol
+
rtol
*
torch
.
abs
(
expected
))
nan_mismatch
=
actual_isnan
^
expected_isnan
if
equal_nan
else
actual_isnan
|
expected_isnan
diff_mask
=
nan_mismatch
|
(
torch
.
abs
(
actual
-
expected
)
>
(
atol
+
rtol
*
torch
.
abs
(
expected
)))
diff_indices
=
torch
.
nonzero
(
diff_mask
,
as_tuple
=
False
)
diff_indices
=
torch
.
nonzero
(
diff_mask
,
as_tuple
=
False
)
delta
=
actual
-
expected
delta
=
actual
-
expected
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment