Commit 7a55b415 authored by wooway777

issue/757 - support equal_nan in test debug

parent f53b8435
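
This commit threads an `equal_nan` option from the CLI down to the tensor comparison helpers, so NaN values at matching positions can be treated as equal in debug comparisons. As a standalone illustration of the underlying `torch.allclose` semantics (the tensors here are made up for demonstration):

```python
import torch

# By IEEE 754, NaN compares unequal to everything, including itself, so
# torch.allclose reports a mismatch unless equal_nan=True is passed.
a = torch.tensor([1.0, float("nan")])
b = torch.tensor([1.0, float("nan")])

print(torch.allclose(a, b))                  # False: NaN != NaN
print(torch.allclose(a, b, equal_nan=True))  # True: co-located NaNs count as equal
```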
@@ -30,8 +30,10 @@ class TestConfig:
num_prerun=10,
num_iterations=1000,
verbose=False,
+ equal_nan=False,
):
self.debug = debug
+ self.equal_nan = equal_nan
self.bench = bench
self.num_prerun = num_prerun
self.num_iterations = num_iterations
@@ -540,7 +542,11 @@ class BaseOperatorTest(ABC):
rtol = test_case.tolerance.get("rtol", 1e-3)
compare_fn = create_test_comparator(
config, atol, rtol, f"{test_case.description} - output_{i}"
config,
atol,
rtol,
f"{test_case.description} - output_{i}",
equal_nan=config.equal_nan,
)
is_valid = compare_fn(infini_out, torch_out)
@@ -589,7 +595,11 @@ class BaseOperatorTest(ABC):
rtol = test_case.tolerance.get("rtol", 1e-3)
compare_fn = create_test_comparator(
- config, atol, rtol, test_case.description
+ config,
+ atol,
+ rtol,
+ test_case.description,
+ equal_nan=config.equal_nan,
)
is_valid = compare_fn(infini_comparison, torch_comparison)
......
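
Taken together, the changes above let a `TestConfig` carry the flag and hand it to the comparator. A hypothetical usage sketch based only on the signatures visible in this diff (`TestConfig` and `create_test_comparator` are this repo's test-framework helpers; the import paths, tensor values, and operator label are illustrative stand-ins):

```python
import torch

# Illustrative imports; the framework's real module paths are not shown in this diff.
# from framework import TestConfig
# from framework.utils import create_test_comparator

config = TestConfig(debug=True, equal_nan=True)

compare_fn = create_test_comparator(
    config,
    atol=1e-4,
    rtol=1e-3,
    mode_name="add - output_0",  # label used in debug output
    equal_nan=config.equal_nan,
)

# Stand-in results; in the framework these come from the infinicore operator
# and the PyTorch reference, respectively.
infini_out = torch.tensor([1.0, float("nan")])
torch_out = torch.tensor([1.0, float("nan")])

is_valid = compare_fn(infini_out, torch_out)  # True with equal_nan enabled
```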
@@ -44,6 +44,7 @@ def get_hardware_args_group(parser):
return hardware_group
def add_common_test_args(parser: argparse.ArgumentParser):
"""
Adds common test/execution arguments to the passed parser object.
@@ -60,13 +61,19 @@ def add_common_test_args(parser: argparse.ArgumentParser):
help="Enable performance benchmarking mode. "
"Options: host (CPU time only), device (GPU time only), both (default)",
)
group.add_argument(
"--debug",
action="store_true",
help="Enable debug mode for detailed tensor comparison",
)
+ group.add_argument(
+ "--eq_nan",
+ action="store_true",
+ help="Enable equal_nan for tensor comparison",
+ )
group.add_argument(
"--verbose",
action="store_true",
@@ -81,6 +88,7 @@ def add_common_test_args(parser: argparse.ArgumentParser):
help="Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided.",
)
def get_args():
"""Parse command line arguments for operator testing"""
parser = argparse.ArgumentParser(
@@ -100,9 +108,12 @@ Examples:
# Run with benchmarking - device timing only
python test_operator.py --nvidia --bench device
- # Run with debug mode on multiple devices
+ # Run with basic debug mode on multiple devices
python test_operator.py --cpu --nvidia --debug
+ # Run with eq_nan debug mode to treat NaN as equal
+ python test_operator.py --cpu --nvidia --debug --eq_nan
# Run with verbose mode to stop on first error with full traceback
python test_operator.py --cpu --nvidia --verbose
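
The new `--eq_nan` switch added above is a plain `store_true` flag, so it defaults to False when omitted. A self-contained sketch of the parsing behavior (the parser below is a throwaway, not the framework's):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--debug", action="store_true",
                    help="Enable debug mode for detailed tensor comparison")
parser.add_argument("--eq_nan", action="store_true",
                    help="Enable equal_nan for tensor comparison")

args = parser.parse_args(["--debug", "--eq_nan"])
print(args.debug, args.eq_nan)  # True True; both are False when the flags are omitted
```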
@@ -216,7 +227,7 @@ def get_test_devices(args):
devices_to_test.append(InfiniDeviceEnum.HYGON)
except ImportError:
print("Warning: Hygon DCU support not available")
if args.qy:
try:
# Iluvatar GPU detection
......
@@ -9,6 +9,7 @@ import re
from . import TestConfig, TestRunner, get_args, get_test_devices
from .reporter import TestReporter
class GenericTestRunner:
"""Generic test runner that handles the common execution flow"""
@@ -33,7 +34,8 @@ class GenericTestRunner:
bench=self.args.bench,
num_prerun=self.args.num_prerun,
num_iterations=self.args.num_iterations,
- verbose=self.args.verbose, # Pass verbose flag to TestConfig
+ verbose=self.args.verbose,
+ equal_nan=self.args.eq_nan,
)
runner = TestRunner(self.operator_test.test_cases, config)
@@ -53,9 +55,9 @@ class GenericTestRunner:
# summary_passed returns True if no tests failed (skipped/partial are OK)
summary_passed = runner.print_summary()
- if getattr(self.args, 'save', None):
+ if getattr(self.args, "save", None):
self._save_report(runner)
# Both conditions must be True for overall success
# - has_no_failures: no test failures during execution
# - summary_passed: summary confirms no failures
@@ -68,7 +70,7 @@ class GenericTestRunner:
0: All tests passed or were skipped/partial (no failures)
1: One or more tests failed
"""
- success, runner = self.run()
+ success, runner = self.run()
sys.exit(0 if success else 1)
@@ -77,15 +79,14 @@ class GenericTestRunner:
Helper method to collect metadata and trigger report saving.
"""
try:
# 1. Prepare metadata (Paths)
t_path = self._infer_op_path(self.operator_test.torch_operator, "torch")
- i_path = self._infer_op_path(self.operator_test.infinicore_operator, "infinicore")
- op_paths = {
- "torch": t_path,
- "infinicore": i_path
- }
+ i_path = self._infer_op_path(
+ self.operator_test.infinicore_operator, "infinicore"
+ )
+ op_paths = {"torch": t_path, "infinicore": i_path}
# 2. Generate Report Entries
entries = TestReporter.prepare_report_entry(
@@ -93,14 +94,16 @@ class GenericTestRunner:
test_cases=self.operator_test.test_cases,
args=self.args,
op_paths=op_paths,
- results_list=runner.test_results
+ results_list=runner.test_results,
)
# 4. Save to File
TestReporter.save_all_results(self.args.save, entries)
except Exception as e:
- import traceback; traceback.print_exc()
+ import traceback
+ traceback.print_exc()
print(f"⚠️ Failed to save report: {e}")
def _infer_op_path(self, method, lib_prefix):
@@ -113,7 +116,9 @@ class GenericTestRunner:
# Regex to find 'lib.func' or 'lib.submodule.func'
# Matches: 'torch.add', 'torch.nn.functional.relu'
pattern = re.compile(rf"\b{lib_prefix}\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)", re.IGNORECASE)
pattern = re.compile(
rf"\b{lib_prefix}\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)", re.IGNORECASE
)
match = pattern.search(source)
if match:
# Return the matched string exactly as found in source code
......
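
The regex that `_infer_op_path` reformats above can be exercised on its own. It captures the dotted call path after the library prefix; the sample source string below is invented for illustration:

```python
import re

lib_prefix = "torch"
pattern = re.compile(
    rf"\b{lib_prefix}\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)", re.IGNORECASE
)

source = "return torch.nn.functional.relu(a + b)"
match = pattern.search(source)
print(match.group(0))  # torch.nn.functional.relu, exactly as found in the source
```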
@@ -91,6 +91,7 @@ def print_discrepancy(
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Equal NaN: {equal_nan}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
@@ -169,7 +170,7 @@ def convert_infinicore_to_torch(infini_result):
def compare_results(
- infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False
+ infini_result, torch_result, atol=1e-5, rtol=1e-5, equal_nan=False, debug_mode=False
):
"""
Generic function to compare infinicore result with PyTorch reference result
@@ -180,6 +181,7 @@ def compare_results(
torch_result: PyTorch tensor reference result (single or tuple)
atol: absolute tolerance (for floating-point only)
rtol: relative tolerance (for floating-point only)
+ equal_nan: whether to treat NaN as equal
debug_mode: whether to enable debug output
Returns:
@@ -194,7 +196,9 @@ def compare_results(
all_match = True
for i, (infini_out, torch_out) in enumerate(zip(infini_result, torch_result)):
- match = compare_results(infini_out, torch_out, atol, rtol, debug_mode)
+ match = compare_results(
+ infini_out, torch_out, atol, rtol, equal_nan, debug_mode
+ )
all_match = all_match and match
return all_match
@@ -241,7 +245,13 @@ def compare_results(
# Debug mode: detailed comparison
if debug_mode:
- debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
+ debug(
+ torch_result_from_infini,
+ torch_result,
+ atol=atol,
+ rtol=rtol,
+ equal_nan=equal_nan,
+ )
# Choose comparison method based on data type
if is_integer_dtype(torch_result_from_infini.dtype) or is_integer_dtype(
@@ -257,10 +267,18 @@ def compare_results(
):
# Complex number comparison - compare real and imaginary parts separately
real_close = torch.allclose(
- torch_result_from_infini.real, torch_result.real, atol=atol, rtol=rtol
+ torch_result_from_infini.real,
+ torch_result.real,
+ atol=atol,
+ rtol=rtol,
+ equal_nan=equal_nan,
)
imag_close = torch.allclose(
- torch_result_from_infini.imag, torch_result.imag, atol=atol, rtol=rtol
+ torch_result_from_infini.imag,
+ torch_result.imag,
+ atol=atol,
+ rtol=rtol,
+ equal_nan=equal_nan,
)
result_equal = real_close and imag_close
if debug_mode and not result_equal:
@@ -273,11 +291,15 @@ def compare_results(
else:
# Tolerance-based comparison for floating-point types
return torch.allclose(
- torch_result_from_infini, torch_result, atol=atol, rtol=rtol
+ torch_result_from_infini,
+ torch_result,
+ atol=atol,
+ rtol=rtol,
+ equal_nan=equal_nan,
)
def create_test_comparator(config, atol, rtol, mode_name=""):
def create_test_comparator(config, atol, rtol, mode_name="", equal_nan=False):
"""
Create a test-specific comparison function
@@ -286,6 +308,7 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
atol: absolute tolerance (for floating-point only)
rtol: relative tolerance (for floating-point only)
mode_name: operation mode name for debug output
+ equal_nan: whether to treat NaN as equal
Returns:
callable: function that takes (infini_result, torch_result) and returns bool
@@ -294,6 +317,9 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
def compare_test_results(infini_result, torch_result):
if config.debug and mode_name:
print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m")
+ print(
+ f"\033[94m Equal NaN: {'enabled' if equal_nan else 'disabled'}\033[0m"
+ )
# For integer types, override tolerance to require exact equality
actual_atol = atol
@@ -316,6 +342,7 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
torch_result,
atol=actual_atol,
rtol=actual_rtol,
+ equal_nan=equal_nan,
debug_mode=config.debug,
)
......
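
The complex-number branch above compares real and imaginary parts separately, so the `equal_nan` semantics apply component-wise. A minimal standalone check of that approach (values invented for demonstration):

```python
import torch

a = torch.tensor([1.0 + 2.0j, complex(float("nan"), 0.0)])
b = torch.tensor([1.0 + 2.0j, complex(float("nan"), 0.0)])

# Compare real and imaginary parts separately, as the diff does.
real_close = torch.allclose(a.real, b.real, atol=1e-5, rtol=1e-5, equal_nan=True)
imag_close = torch.allclose(a.imag, b.imag, atol=1e-5, rtol=1e-5, equal_nan=True)
print(real_close and imag_close)  # True: the co-located NaN in the real parts matches
```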