Unverified commit bbf18d2c, authored by thatPepe and committed by GitHub
Browse files

Merge pull request #758 from InfiniTensor/issue/757

issue/757 - support equal_nan in test debug
parents f53b8435 7a55b415
...@@ -30,8 +30,10 @@ class TestConfig: ...@@ -30,8 +30,10 @@ class TestConfig:
num_prerun=10, num_prerun=10,
num_iterations=1000, num_iterations=1000,
verbose=False, verbose=False,
equal_nan=False,
): ):
self.debug = debug self.debug = debug
self.equal_nan = equal_nan
self.bench = bench self.bench = bench
self.num_prerun = num_prerun self.num_prerun = num_prerun
self.num_iterations = num_iterations self.num_iterations = num_iterations
...@@ -540,7 +542,11 @@ class BaseOperatorTest(ABC): ...@@ -540,7 +542,11 @@ class BaseOperatorTest(ABC):
rtol = test_case.tolerance.get("rtol", 1e-3) rtol = test_case.tolerance.get("rtol", 1e-3)
compare_fn = create_test_comparator( compare_fn = create_test_comparator(
config, atol, rtol, f"{test_case.description} - output_{i}" config,
atol,
rtol,
f"{test_case.description} - output_{i}",
equal_nan=config.equal_nan,
) )
is_valid = compare_fn(infini_out, torch_out) is_valid = compare_fn(infini_out, torch_out)
...@@ -589,7 +595,11 @@ class BaseOperatorTest(ABC): ...@@ -589,7 +595,11 @@ class BaseOperatorTest(ABC):
rtol = test_case.tolerance.get("rtol", 1e-3) rtol = test_case.tolerance.get("rtol", 1e-3)
compare_fn = create_test_comparator( compare_fn = create_test_comparator(
config, atol, rtol, test_case.description config,
atol,
rtol,
test_case.description,
equal_nan=config.equal_nan,
) )
is_valid = compare_fn(infini_comparison, torch_comparison) is_valid = compare_fn(infini_comparison, torch_comparison)
......
...@@ -44,6 +44,7 @@ def get_hardware_args_group(parser): ...@@ -44,6 +44,7 @@ def get_hardware_args_group(parser):
return hardware_group return hardware_group
def add_common_test_args(parser: argparse.ArgumentParser): def add_common_test_args(parser: argparse.ArgumentParser):
""" """
Adds common test/execution arguments to the passed parser object. Adds common test/execution arguments to the passed parser object.
...@@ -67,6 +68,12 @@ def add_common_test_args(parser: argparse.ArgumentParser): ...@@ -67,6 +68,12 @@ def add_common_test_args(parser: argparse.ArgumentParser):
help="Enable debug mode for detailed tensor comparison", help="Enable debug mode for detailed tensor comparison",
) )
group.add_argument(
"--eq_nan",
action="store_true",
help="Enable equal_nan for tensor comparison",
)
group.add_argument( group.add_argument(
"--verbose", "--verbose",
action="store_true", action="store_true",
...@@ -81,6 +88,7 @@ def add_common_test_args(parser: argparse.ArgumentParser): ...@@ -81,6 +88,7 @@ def add_common_test_args(parser: argparse.ArgumentParser):
help="Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided.", help="Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided.",
) )
def get_args(): def get_args():
"""Parse command line arguments for operator testing""" """Parse command line arguments for operator testing"""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -100,9 +108,12 @@ Examples: ...@@ -100,9 +108,12 @@ Examples:
# Run with benchmarking - device timing only # Run with benchmarking - device timing only
python test_operator.py --nvidia --bench device python test_operator.py --nvidia --bench device
# Run with debug mode on multiple devices # Run with basic debug mode on multiple devices
python test_operator.py --cpu --nvidia --debug python test_operator.py --cpu --nvidia --debug
# Run with eq_nan debug mode to treat NaN as equal
python test_operator.py --cpu --nvidia --debug --eq_nan
# Run with verbose mode to stop on first error with full traceback # Run with verbose mode to stop on first error with full traceback
python test_operator.py --cpu --nvidia --verbose python test_operator.py --cpu --nvidia --verbose
......
...@@ -9,6 +9,7 @@ import re ...@@ -9,6 +9,7 @@ import re
from . import TestConfig, TestRunner, get_args, get_test_devices from . import TestConfig, TestRunner, get_args, get_test_devices
from .reporter import TestReporter from .reporter import TestReporter
class GenericTestRunner: class GenericTestRunner:
"""Generic test runner that handles the common execution flow""" """Generic test runner that handles the common execution flow"""
...@@ -33,7 +34,8 @@ class GenericTestRunner: ...@@ -33,7 +34,8 @@ class GenericTestRunner:
bench=self.args.bench, bench=self.args.bench,
num_prerun=self.args.num_prerun, num_prerun=self.args.num_prerun,
num_iterations=self.args.num_iterations, num_iterations=self.args.num_iterations,
verbose=self.args.verbose, # Pass verbose flag to TestConfig verbose=self.args.verbose,
equal_nan=self.args.eq_nan,
) )
runner = TestRunner(self.operator_test.test_cases, config) runner = TestRunner(self.operator_test.test_cases, config)
...@@ -53,7 +55,7 @@ class GenericTestRunner: ...@@ -53,7 +55,7 @@ class GenericTestRunner:
# summary_passed returns True if no tests failed (skipped/partial are OK) # summary_passed returns True if no tests failed (skipped/partial are OK)
summary_passed = runner.print_summary() summary_passed = runner.print_summary()
if getattr(self.args, 'save', None): if getattr(self.args, "save", None):
self._save_report(runner) self._save_report(runner)
# Both conditions must be True for overall success # Both conditions must be True for overall success
...@@ -80,12 +82,11 @@ class GenericTestRunner: ...@@ -80,12 +82,11 @@ class GenericTestRunner:
# 1. Prepare metadata (Paths) # 1. Prepare metadata (Paths)
t_path = self._infer_op_path(self.operator_test.torch_operator, "torch") t_path = self._infer_op_path(self.operator_test.torch_operator, "torch")
i_path = self._infer_op_path(self.operator_test.infinicore_operator, "infinicore") i_path = self._infer_op_path(
self.operator_test.infinicore_operator, "infinicore"
)
op_paths = { op_paths = {"torch": t_path, "infinicore": i_path}
"torch": t_path,
"infinicore": i_path
}
# 2. Generate Report Entries # 2. Generate Report Entries
entries = TestReporter.prepare_report_entry( entries = TestReporter.prepare_report_entry(
...@@ -93,14 +94,16 @@ class GenericTestRunner: ...@@ -93,14 +94,16 @@ class GenericTestRunner:
test_cases=self.operator_test.test_cases, test_cases=self.operator_test.test_cases,
args=self.args, args=self.args,
op_paths=op_paths, op_paths=op_paths,
results_list=runner.test_results results_list=runner.test_results,
) )
# 4. Save to File # 4. Save to File
TestReporter.save_all_results(self.args.save, entries) TestReporter.save_all_results(self.args.save, entries)
except Exception as e: except Exception as e:
import traceback; traceback.print_exc() import traceback
traceback.print_exc()
print(f"⚠️ Failed to save report: {e}") print(f"⚠️ Failed to save report: {e}")
def _infer_op_path(self, method, lib_prefix): def _infer_op_path(self, method, lib_prefix):
...@@ -113,7 +116,9 @@ class GenericTestRunner: ...@@ -113,7 +116,9 @@ class GenericTestRunner:
# Regex to find 'lib.func' or 'lib.submodule.func' # Regex to find 'lib.func' or 'lib.submodule.func'
# Matches: 'torch.add', 'torch.nn.functional.relu' # Matches: 'torch.add', 'torch.nn.functional.relu'
pattern = re.compile(rf"\b{lib_prefix}\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)", re.IGNORECASE) pattern = re.compile(
rf"\b{lib_prefix}\.([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+)*)", re.IGNORECASE
)
match = pattern.search(source) match = pattern.search(source)
if match: if match:
# Return the matched string exactly as found in source code # Return the matched string exactly as found in source code
......
...@@ -91,6 +91,7 @@ def print_discrepancy( ...@@ -91,6 +91,7 @@ def print_discrepancy(
print(f" - Desired dtype: {expected.dtype}") print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}") print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}") print(f" - Rtol: {rtol}")
print(f" - Equal NaN: {equal_nan}")
print( print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)" f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
) )
...@@ -169,7 +170,7 @@ def convert_infinicore_to_torch(infini_result): ...@@ -169,7 +170,7 @@ def convert_infinicore_to_torch(infini_result):
def compare_results( def compare_results(
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False infini_result, torch_result, atol=1e-5, rtol=1e-5, equal_nan=False, debug_mode=False
): ):
""" """
Generic function to compare infinicore result with PyTorch reference result Generic function to compare infinicore result with PyTorch reference result
...@@ -180,6 +181,7 @@ def compare_results( ...@@ -180,6 +181,7 @@ def compare_results(
torch_result: PyTorch tensor reference result (single or tuple) torch_result: PyTorch tensor reference result (single or tuple)
atol: absolute tolerance (for floating-point only) atol: absolute tolerance (for floating-point only)
rtol: relative tolerance (for floating-point only) rtol: relative tolerance (for floating-point only)
equal_nan: whether to treat NaN as equal
debug_mode: whether to enable debug output debug_mode: whether to enable debug output
Returns: Returns:
...@@ -194,7 +196,9 @@ def compare_results( ...@@ -194,7 +196,9 @@ def compare_results(
all_match = True all_match = True
for i, (infini_out, torch_out) in enumerate(zip(infini_result, torch_result)): for i, (infini_out, torch_out) in enumerate(zip(infini_result, torch_result)):
match = compare_results(infini_out, torch_out, atol, rtol, debug_mode) match = compare_results(
infini_out, torch_out, atol, rtol, equal_nan, debug_mode
)
all_match = all_match and match all_match = all_match and match
return all_match return all_match
...@@ -241,7 +245,13 @@ def compare_results( ...@@ -241,7 +245,13 @@ def compare_results(
# Debug mode: detailed comparison # Debug mode: detailed comparison
if debug_mode: if debug_mode:
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol) debug(
torch_result_from_infini,
torch_result,
atol=atol,
rtol=rtol,
equal_nan=equal_nan,
)
# Choose comparison method based on data type # Choose comparison method based on data type
if is_integer_dtype(torch_result_from_infini.dtype) or is_integer_dtype( if is_integer_dtype(torch_result_from_infini.dtype) or is_integer_dtype(
...@@ -257,10 +267,18 @@ def compare_results( ...@@ -257,10 +267,18 @@ def compare_results(
): ):
# Complex number comparison - compare real and imaginary parts separately # Complex number comparison - compare real and imaginary parts separately
real_close = torch.allclose( real_close = torch.allclose(
torch_result_from_infini.real, torch_result.real, atol=atol, rtol=rtol torch_result_from_infini.real,
torch_result.real,
atol=atol,
rtol=rtol,
equal_nan=equal_nan,
) )
imag_close = torch.allclose( imag_close = torch.allclose(
torch_result_from_infini.imag, torch_result.imag, atol=atol, rtol=rtol torch_result_from_infini.imag,
torch_result.imag,
atol=atol,
rtol=rtol,
equal_nan=equal_nan,
) )
result_equal = real_close and imag_close result_equal = real_close and imag_close
if debug_mode and not result_equal: if debug_mode and not result_equal:
...@@ -273,11 +291,15 @@ def compare_results( ...@@ -273,11 +291,15 @@ def compare_results(
else: else:
# Tolerance-based comparison for floating-point types # Tolerance-based comparison for floating-point types
return torch.allclose( return torch.allclose(
torch_result_from_infini, torch_result, atol=atol, rtol=rtol torch_result_from_infini,
torch_result,
atol=atol,
rtol=rtol,
equal_nan=equal_nan,
) )
def create_test_comparator(config, atol, rtol, mode_name=""): def create_test_comparator(config, atol, rtol, mode_name="", equal_nan=False):
""" """
Create a test-specific comparison function Create a test-specific comparison function
...@@ -286,6 +308,7 @@ def create_test_comparator(config, atol, rtol, mode_name=""): ...@@ -286,6 +308,7 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
atol: absolute tolerance (for floating-point only) atol: absolute tolerance (for floating-point only)
rtol: relative tolerance (for floating-point only) rtol: relative tolerance (for floating-point only)
mode_name: operation mode name for debug output mode_name: operation mode name for debug output
equal_nan: whether to treat NaN as equal
Returns: Returns:
callable: function that takes (infini_result, torch_result) and returns bool callable: function that takes (infini_result, torch_result) and returns bool
...@@ -294,6 +317,9 @@ def create_test_comparator(config, atol, rtol, mode_name=""): ...@@ -294,6 +317,9 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
def compare_test_results(infini_result, torch_result): def compare_test_results(infini_result, torch_result):
if config.debug and mode_name: if config.debug and mode_name:
print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m") print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m")
print(
f"\033[94m Equal NaN: {'enabled' if equal_nan else 'disabled'}\033[0m"
)
# For integer types, override tolerance to require exact equality # For integer types, override tolerance to require exact equality
actual_atol = atol actual_atol = atol
...@@ -316,6 +342,7 @@ def create_test_comparator(config, atol, rtol, mode_name=""): ...@@ -316,6 +342,7 @@ def create_test_comparator(config, atol, rtol, mode_name=""):
torch_result, torch_result,
atol=actual_atol, atol=actual_atol,
rtol=actual_rtol, rtol=actual_rtol,
equal_nan=equal_nan,
debug_mode=config.debug, debug_mode=config.debug,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment