Commit 631a1fc7 authored by baominghelly's avatar baominghelly
Browse files

issue/877 - return saved_file in TestManager.test method

parent caa61e9e
......@@ -59,6 +59,9 @@ class TestExecutor:
test_summary = TestSummary()
test_summary.process_operator_result(result, test_results)
# Store saved report file if available
result.saved_file = runner.saved_file
except Exception as e:
result.success = False
result.error_message = str(e)
......
......@@ -43,6 +43,7 @@ class OperatorResult:
stdout: str = ""
stderr: str = ""
timing: TestTiming = field(default_factory=TestTiming)
saved_file: str = "" # Path to the saved report file
@property
def status_icon(self):
......@@ -262,11 +263,12 @@ class TestSummary:
def save_report(self, save_path):
    """
    Save the collected report entries to *save_path*.

    Delegates the actual writing to ``save_json_report``, which produces
    the real on-disk file name (with a timestamp appended).

    Args:
        save_path (str): Requested destination path for the JSON report.

    Returns:
        str | None: The actual file path that was written (with timestamp),
        or ``None`` when there are no report entries to save.
    """
    # Nothing collected -> nothing to write; return None explicitly so the
    # caller can distinguish "no report" from a saved path.
    if not self.report_entries:
        return None
    # Call the external utility and propagate the actual saved path.
    return save_json_report(save_path, self.report_entries)
def _prepare_entry_logic(self, op_name, test_cases, args, op_paths, results_list):
"""
......
......@@ -20,6 +20,7 @@ class GenericTestRunner:
"""
self.operator_test = operator_test_class()
self.args = args or get_args()
self.saved_file = None # Store the path of saved report
def run(self):
"""Execute the complete test suite
......@@ -56,7 +57,7 @@ class GenericTestRunner:
summary_passed = runner.print_summary()
if getattr(self.args, "save", None):
self._save_report(runner)
self.saved_file = self._save_report(runner)
# Both conditions must be True for overall success
# - has_no_failures: no test failures during execution
......@@ -98,14 +99,15 @@ class GenericTestRunner:
results_list=runner.test_results,
)
# 4. Save to File
test_summary.save_report(self.args.save)
# 3. Save to File and return the file name
return test_summary.save_report(self.args.save)
except Exception as e:
import traceback
traceback.print_exc()
print(f"⚠️ Failed to save report: {e}")
return None
def _infer_op_path(self, method, lib_prefix):
"""
......
......@@ -138,6 +138,7 @@ class TestManager:
self.summary.print_header(display_location, len(test_files))
saved_files = []
for f, run_args in zip(test_files, test_configs):
# Inject prepared args (whether from JSON or Local global) into Executor
......@@ -146,6 +147,10 @@ class TestManager:
self.results.append(result)
self.summary.print_live_result(result)
# Collect saved report files
if hasattr(result, "saved_file") and result.saved_file:
saved_files.append(result.saved_file)
if result.success:
self._accumulate_timing(result.timing)
......@@ -160,7 +165,8 @@ class TestManager:
ops_dir=display_location,
total_expected=len(test_files),
)
return all_passed
return all_passed, saved_files
def _accumulate_timing(self, timing):
self.cumulative_timing.torch_host += timing.torch_host
......
......@@ -7,6 +7,9 @@ def save_json_report(save_path, total_results):
"""
Saves the report list to a JSON file with specific custom formatting
(Compact for short lines, Expanded for long lines).
Returns:
str: The actual file path that was saved (with timestamp), or None if failed.
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
......@@ -85,11 +88,13 @@ def save_json_report(save_path, total_results):
f.write(f"{I4}{close_entry}\n")
f.write("]\n")
print(f" ✅ Saved.")
return final_path
except Exception as e:
import traceback
traceback.print_exc()
print(f" ❌ Save failed: {e}")
return None
def _write_field(f, key, value, indent, sub_indent, close_comma=""):
......
......@@ -130,12 +130,12 @@ def load_and_override_cases(load_paths, args):
for f_path in files_to_read:
try:
with open(f_path, 'r', encoding='utf-8') as f:
with open(f_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Unify as a list to handle both single dict and list of dicts
current_batch = data if isinstance(data, list) else [data]
valid_batch = []
for item in current_batch:
# We only require the 'operator' field to identify the test case.
......@@ -143,11 +143,11 @@ def load_and_override_cases(load_paths, args):
valid_batch.append(item)
else:
skipped_count += 1
if valid_batch:
cases.extend(valid_batch)
loaded_count += 1
except Exception as e:
# Log warning only; do not crash the program on bad files to ensure flow continuity.
print(f"❌ Error loading {f_path.name}: {e}")
......@@ -173,7 +173,7 @@ def load_and_override_cases(load_paths, args):
cli_active_devices.append(device_name)
print(f"\n[Config Processing]")
for case in cases:
if "args" not in case or case["args"] is None:
case["args"] = {}
......@@ -283,9 +283,11 @@ def main():
print(f"Benchmark mode: {args.bench.upper()} timing")
# 3. Initialize and Execute
test_manager = TestManager(ops_dir=args.ops_dir, verbose=verbose, bench_mode=bench)
test_manager = TestManager(
ops_dir=args.ops_dir, verbose=verbose, bench_mode=bench
)
success = test_manager.test(json_cases_list=json_cases)
success, _ = test_manager.test(json_cases_list=json_cases)
# ==========================================================================
# Branch 2: Local Scan Mode
......@@ -330,7 +332,7 @@ def main():
ops_dir=args.ops_dir, verbose=args.verbose, bench_mode=args.bench
)
success = test_manager.test(
success, _ = test_manager.test(
target_ops=target_ops, global_exec_args=global_exec_args
)
sys.exit(0 if success else 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment