Commit 32bd2f82 authored by baominghelly's avatar baominghelly
Browse files

Issue/787 - class/method rename && move structs code to results file

parent 515b92fb
......@@ -9,11 +9,10 @@ from .config import (
)
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .results import TestTiming, OperatorResult, CaseResult, TestSummary
from .runner import GenericTestRunner
from .tensor import TensorSpec, TensorInitializer
from .structs import TestTiming, OperatorResult, CaseResult
from .summary import TestSummary
from .driver import TestDriver
from .executor import TestExecutor
from .utils.compare_utils import (
compare_results,
create_test_comparator,
......@@ -44,7 +43,7 @@ __all__ = [
"TensorSpec",
"TestCase",
"TestConfig",
"TestDriver",
"TestExecutor",
"TestSummary",
"TestRunner",
"TestTiming",
......
......@@ -8,7 +8,7 @@ import infinicore
import traceback
from abc import ABC, abstractmethod
from .structs import CaseResult
from .results import CaseResult
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
......
......@@ -2,8 +2,7 @@ import sys
import importlib.util
from io import StringIO
from contextlib import contextmanager
from .structs import OperatorResult
from .summary import TestSummary
from .results import OperatorResult, TestSummary
@contextmanager
......@@ -18,8 +17,8 @@ def capture_output():
sys.stdout, sys.stderr = old_out, old_err
class TestDriver:
def drive(self, file_path) -> OperatorResult:
class TestExecutor:
def execute(self, file_path) -> OperatorResult:
result = OperatorResult(name=file_path.stem)
try:
......@@ -53,7 +52,6 @@ class TestDriver:
test_summary = TestSummary()
test_summary.process_operator_result(result, test_results)
# test_summary._extract_timing(result, test_results)
except Exception as e:
result.success = False
......
from typing import List, Dict, Any
from dataclasses import is_dataclass
from dataclasses import dataclass, is_dataclass, field
from .devices import InfiniDeviceEnum
from .base import TensorSpec
from .tensor import TensorSpec
from .utils.json_utils import save_json_report
@dataclass
class CaseResult:
"""Test case result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
@dataclass
class TestTiming:
"""Stores performance timing metrics."""
torch_host: float = 0.0
torch_device: float = 0.0
infini_host: float = 0.0
infini_device: float = 0.0
# Added field to support the logic in your print_summary
operators_tested: int = 0
@dataclass
class OperatorResult:
"""Stores the execution results of a single operator."""
name: str
success: bool = False
return_code: int = -1
error_message: str = ""
stdout: str = ""
stderr: str = ""
timing: TestTiming = field(default_factory=TestTiming)
@property
def status_icon(self):
if self.return_code == 0:
return "✅"
if self.return_code == -2:
return "⏭️"
if self.return_code == -3:
return "⚠️"
return "❌"
@property
def status_text(self):
if self.return_code == 0:
return "PASSED"
if self.return_code == -2:
return "SKIPPED"
if self.return_code == -3:
return "PARTIAL"
return "FAILED"
class TestSummary:
"""
Test summary:
Test Summary class:
1. Aggregates results (Timing & Status calculation).
2. Handles Console Output (Live & Summary).
3. Handles File Reporting (Data Preparation).
......
......@@ -7,7 +7,7 @@ import os
import inspect
import re
from . import TestConfig, TestRunner, get_args, get_test_devices
from .summary import TestSummary
from .results import TestSummary
class GenericTestRunner:
......
from dataclasses import dataclass, field
from typing import Any
@dataclass
class CaseResult:
"""Test case result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
@dataclass
class TestTiming:
"""Stores performance timing metrics."""
torch_host: float = 0.0
torch_device: float = 0.0
infini_host: float = 0.0
infini_device: float = 0.0
# Added field to support the logic in your print_summary
operators_tested: int = 0
@dataclass
class OperatorResult:
"""Stores the execution results of a single operator."""
name: str
success: bool = False
return_code: int = -1
error_message: str = ""
stdout: str = ""
stderr: str = ""
timing: TestTiming = field(default_factory=TestTiming)
@property
def status_icon(self):
if self.return_code == 0:
return "✅"
if self.return_code == -2:
return "⏭️"
if self.return_code == -3:
return "⚠️"
return "❌"
@property
def status_text(self):
if self.return_code == 0:
return "PASSED"
if self.return_code == -2:
return "SKIPPED"
if self.return_code == -3:
return "PARTIAL"
return "FAILED"
......@@ -19,7 +19,7 @@ def save_json_report(save_path, total_results):
print(f"💾 Saving to: {final_path}")
# Helper for JSON stringify to avoid repetition
def _j(obj):
def _to_json(obj):
return json.dumps(obj, ensure_ascii=False)
try:
......@@ -58,16 +58,16 @@ def save_json_report(save_path, total_results):
)
if c_key in ["kwargs", "inputs"]:
_write_smart_field(
_write_field(
f, c_key, c_val, I16, I20, close_comma=c_comma
)
else:
f.write(f'{I16}"{c_key}": {_j(c_val)}{c_comma}\n')
f.write(f'{I16}"{c_key}": {_to_json(c_val)}{c_comma}\n')
# Handle trailing comparison/tolerance fields uniformly
if "comparison_target" in case_item:
cmp = _j(case_item.get("comparison_target"))
tol = _j(case_item.get("tolerance"))
cmp = _to_json(case_item.get("comparison_target"))
tol = _to_json(case_item.get("tolerance"))
f.write(
f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n'
)
......@@ -77,7 +77,7 @@ def save_json_report(save_path, total_results):
f.write(f"{I8}]{comma}\n")
else:
# Standard top-level fields
f.write(f"{I8}{_j(key)}: {_j(val)}{comma}\n")
f.write(f"{I8}{_to_json(key)}: {_to_json(val)}{comma}\n")
close_entry = "}," if i < len(total_results) - 1 else "}"
f.write(f"{I4}{close_entry}\n")
......@@ -90,9 +90,9 @@ def save_json_report(save_path, total_results):
print(f" ❌ Save failed: {e}")
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
def _write_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Internal Helper: Write a JSON field with smart wrapping.
Internal Helper: Write a JSON field with wrapping.
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
......
......@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
......
......@@ -6,7 +6,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
from framework.tensor import TensorInitializer
from framework.utils import (
from framework.utils.tensor_utils import (
convert_infinicore_to_torch,
infinicore_tensor_from_torch,
to_torch_dtype,
......
......@@ -222,8 +222,8 @@ class OpTest(BaseOperatorTest):
# Re-run operations with the same logits to get results for comparison
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.base import CaseResult
from framework.utils import (
from framework.results import CaseResult
from framework.utils.tensor_utils import (
convert_infinicore_to_torch,
infinicore_tensor_from_torch,
)
......
......@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
......
......@@ -7,6 +7,7 @@ import torch
import infinicore
from framework import (
BaseOperatorTest,
CaseResult,
TensorSpec,
TestCase,
GenericTestRunner,
......
......@@ -3,9 +3,8 @@ import argparse
from pathlib import Path
# Import components from the unified framework package
from framework.driver import TestDriver
from framework.summary import TestSummary
from framework.structs import TestTiming
from framework.executor import TestExecutor
from framework.results import TestSummary, TestTiming
from framework import get_hardware_args_group, add_common_test_args
......@@ -235,7 +234,7 @@ def main():
sys.exit(0)
# 2. Preparation
driver = TestDriver()
executor = TestExecutor()
cumulative_timing = TestTiming()
test_summary = TestSummary(args.verbose, args.bench)
results = []
......@@ -244,7 +243,7 @@ def main():
# 3. Execution Loop
for f in test_files:
result = driver.drive(f)
result = executor.execute(f)
results.append(result)
# Real-time reporting and printing of stdout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment