Commit 32bd2f82 authored by baominghelly's avatar baominghelly
Browse files

Issue/787 - class/method rename && move structs code to results file

parent 515b92fb
...@@ -9,11 +9,10 @@ from .config import ( ...@@ -9,11 +9,10 @@ from .config import (
) )
from .datatypes import to_torch_dtype, to_infinicore_dtype from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .results import TestTiming, OperatorResult, CaseResult, TestSummary
from .runner import GenericTestRunner from .runner import GenericTestRunner
from .tensor import TensorSpec, TensorInitializer from .tensor import TensorSpec, TensorInitializer
from .structs import TestTiming, OperatorResult, CaseResult from .executor import TestExecutor
from .summary import TestSummary
from .driver import TestDriver
from .utils.compare_utils import ( from .utils.compare_utils import (
compare_results, compare_results,
create_test_comparator, create_test_comparator,
...@@ -44,7 +43,7 @@ __all__ = [ ...@@ -44,7 +43,7 @@ __all__ = [
"TensorSpec", "TensorSpec",
"TestCase", "TestCase",
"TestConfig", "TestConfig",
"TestDriver", "TestExecutor",
"TestSummary", "TestSummary",
"TestRunner", "TestRunner",
"TestTiming", "TestTiming",
......
...@@ -8,7 +8,7 @@ import infinicore ...@@ -8,7 +8,7 @@ import infinicore
import traceback import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from .structs import CaseResult from .results import CaseResult
from .datatypes import to_torch_dtype, to_infinicore_dtype from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer from .tensor import TensorSpec, TensorInitializer
......
...@@ -2,8 +2,7 @@ import sys ...@@ -2,8 +2,7 @@ import sys
import importlib.util import importlib.util
from io import StringIO from io import StringIO
from contextlib import contextmanager from contextlib import contextmanager
from .structs import OperatorResult from .results import OperatorResult, TestSummary
from .summary import TestSummary
@contextmanager @contextmanager
...@@ -18,8 +17,8 @@ def capture_output(): ...@@ -18,8 +17,8 @@ def capture_output():
sys.stdout, sys.stderr = old_out, old_err sys.stdout, sys.stderr = old_out, old_err
class TestDriver: class TestExecutor:
def drive(self, file_path) -> OperatorResult: def execute(self, file_path) -> OperatorResult:
result = OperatorResult(name=file_path.stem) result = OperatorResult(name=file_path.stem)
try: try:
...@@ -53,7 +52,6 @@ class TestDriver: ...@@ -53,7 +52,6 @@ class TestDriver:
test_summary = TestSummary() test_summary = TestSummary()
test_summary.process_operator_result(result, test_results) test_summary.process_operator_result(result, test_results)
# test_summary._extract_timing(result, test_results)
except Exception as e: except Exception as e:
result.success = False result.success = False
......
from typing import List, Dict, Any from typing import List, Dict, Any
from dataclasses import is_dataclass from dataclasses import dataclass, is_dataclass, field
from .devices import InfiniDeviceEnum from .devices import InfiniDeviceEnum
from .base import TensorSpec from .tensor import TensorSpec
from .utils.json_utils import save_json_report from .utils.json_utils import save_json_report
@dataclass
class CaseResult:
"""Test case result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
@dataclass
class TestTiming:
"""Stores performance timing metrics."""
torch_host: float = 0.0
torch_device: float = 0.0
infini_host: float = 0.0
infini_device: float = 0.0
# Added field to support the logic in your print_summary
operators_tested: int = 0
@dataclass
class OperatorResult:
"""Stores the execution results of a single operator."""
name: str
success: bool = False
return_code: int = -1
error_message: str = ""
stdout: str = ""
stderr: str = ""
timing: TestTiming = field(default_factory=TestTiming)
@property
def status_icon(self):
if self.return_code == 0:
return "✅"
if self.return_code == -2:
return "⏭️"
if self.return_code == -3:
return "⚠️"
return "❌"
@property
def status_text(self):
if self.return_code == 0:
return "PASSED"
if self.return_code == -2:
return "SKIPPED"
if self.return_code == -3:
return "PARTIAL"
return "FAILED"
class TestSummary: class TestSummary:
""" """
Test summary: Test Summary class:
1. Aggregates results (Timing & Status calculation). 1. Aggregates results (Timing & Status calculation).
2. Handles Console Output (Live & Summary). 2. Handles Console Output (Live & Summary).
3. Handles File Reporting (Data Preparation). 3. Handles File Reporting (Data Preparation).
......
...@@ -7,7 +7,7 @@ import os ...@@ -7,7 +7,7 @@ import os
import inspect import inspect
import re import re
from . import TestConfig, TestRunner, get_args, get_test_devices from . import TestConfig, TestRunner, get_args, get_test_devices
from .summary import TestSummary from .results import TestSummary
class GenericTestRunner: class GenericTestRunner:
......
from dataclasses import dataclass, field
from typing import Any
@dataclass
class CaseResult:
"""Test case result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
@dataclass
class TestTiming:
"""Stores performance timing metrics."""
torch_host: float = 0.0
torch_device: float = 0.0
infini_host: float = 0.0
infini_device: float = 0.0
# Added field to support the logic in your print_summary
operators_tested: int = 0
@dataclass
class OperatorResult:
"""Stores the execution results of a single operator."""
name: str
success: bool = False
return_code: int = -1
error_message: str = ""
stdout: str = ""
stderr: str = ""
timing: TestTiming = field(default_factory=TestTiming)
@property
def status_icon(self):
if self.return_code == 0:
return "✅"
if self.return_code == -2:
return "⏭️"
if self.return_code == -3:
return "⚠️"
return "❌"
@property
def status_text(self):
if self.return_code == 0:
return "PASSED"
if self.return_code == -2:
return "SKIPPED"
if self.return_code == -3:
return "PARTIAL"
return "FAILED"
...@@ -19,7 +19,7 @@ def save_json_report(save_path, total_results): ...@@ -19,7 +19,7 @@ def save_json_report(save_path, total_results):
print(f"💾 Saving to: {final_path}") print(f"💾 Saving to: {final_path}")
# Helper for JSON stringify to avoid repetition # Helper for JSON stringify to avoid repetition
def _j(obj): def _to_json(obj):
return json.dumps(obj, ensure_ascii=False) return json.dumps(obj, ensure_ascii=False)
try: try:
...@@ -58,16 +58,16 @@ def save_json_report(save_path, total_results): ...@@ -58,16 +58,16 @@ def save_json_report(save_path, total_results):
) )
if c_key in ["kwargs", "inputs"]: if c_key in ["kwargs", "inputs"]:
_write_smart_field( _write_field(
f, c_key, c_val, I16, I20, close_comma=c_comma f, c_key, c_val, I16, I20, close_comma=c_comma
) )
else: else:
f.write(f'{I16}"{c_key}": {_j(c_val)}{c_comma}\n') f.write(f'{I16}"{c_key}": {_to_json(c_val)}{c_comma}\n')
# Handle trailing comparison/tolerance fields uniformly # Handle trailing comparison/tolerance fields uniformly
if "comparison_target" in case_item: if "comparison_target" in case_item:
cmp = _j(case_item.get("comparison_target")) cmp = _to_json(case_item.get("comparison_target"))
tol = _j(case_item.get("tolerance")) tol = _to_json(case_item.get("tolerance"))
f.write( f.write(
f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n' f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n'
) )
...@@ -77,7 +77,7 @@ def save_json_report(save_path, total_results): ...@@ -77,7 +77,7 @@ def save_json_report(save_path, total_results):
f.write(f"{I8}]{comma}\n") f.write(f"{I8}]{comma}\n")
else: else:
# Standard top-level fields # Standard top-level fields
f.write(f"{I8}{_j(key)}: {_j(val)}{comma}\n") f.write(f"{I8}{_to_json(key)}: {_to_json(val)}{comma}\n")
close_entry = "}," if i < len(total_results) - 1 else "}" close_entry = "}," if i < len(total_results) - 1 else "}"
f.write(f"{I4}{close_entry}\n") f.write(f"{I4}{close_entry}\n")
...@@ -90,9 +90,9 @@ def save_json_report(save_path, total_results): ...@@ -90,9 +90,9 @@ def save_json_report(save_path, total_results):
print(f" ❌ Save failed: {e}") print(f" ❌ Save failed: {e}")
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""): def _write_field(f, key, value, indent, sub_indent, close_comma=""):
""" """
Internal Helper: Write a JSON field with smart wrapping. Internal Helper: Write a JSON field with wrapping.
""" """
# 1. Try Compact Mode # 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False) compact_json = json.dumps(value, ensure_ascii=False)
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import infinicore import infinicore
from framework import ( from framework import (
BaseOperatorTest, BaseOperatorTest,
CaseResult,
TensorSpec, TensorSpec,
TestCase, TestCase,
GenericTestRunner, GenericTestRunner,
......
...@@ -6,7 +6,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) ...@@ -6,7 +6,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
from framework.tensor import TensorInitializer from framework.tensor import TensorInitializer
from framework.utils import ( from framework.utils.tensor_utils import (
convert_infinicore_to_torch, convert_infinicore_to_torch,
infinicore_tensor_from_torch, infinicore_tensor_from_torch,
to_torch_dtype, to_torch_dtype,
......
...@@ -222,8 +222,8 @@ class OpTest(BaseOperatorTest): ...@@ -222,8 +222,8 @@ class OpTest(BaseOperatorTest):
# Re-run operations with the same logits to get results for comparison # Re-run operations with the same logits to get results for comparison
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists # prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.base import CaseResult from framework.results import CaseResult
from framework.utils import ( from framework.utils.tensor_utils import (
convert_infinicore_to_torch, convert_infinicore_to_torch,
infinicore_tensor_from_torch, infinicore_tensor_from_torch,
) )
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import infinicore import infinicore
from framework import ( from framework import (
BaseOperatorTest, BaseOperatorTest,
CaseResult,
TensorSpec, TensorSpec,
TestCase, TestCase,
GenericTestRunner, GenericTestRunner,
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import infinicore import infinicore
from framework import ( from framework import (
BaseOperatorTest, BaseOperatorTest,
CaseResult,
TensorSpec, TensorSpec,
TestCase, TestCase,
GenericTestRunner, GenericTestRunner,
......
...@@ -3,9 +3,8 @@ import argparse ...@@ -3,9 +3,8 @@ import argparse
from pathlib import Path from pathlib import Path
# Import components from the unified framework package # Import components from the unified framework package
from framework.driver import TestDriver from framework.executor import TestExecutor
from framework.summary import TestSummary from framework.results import TestSummary, TestTiming
from framework.structs import TestTiming
from framework import get_hardware_args_group, add_common_test_args from framework import get_hardware_args_group, add_common_test_args
...@@ -235,7 +234,7 @@ def main(): ...@@ -235,7 +234,7 @@ def main():
sys.exit(0) sys.exit(0)
# 2. Preparation # 2. Preparation
driver = TestDriver() executor = TestExecutor()
cumulative_timing = TestTiming() cumulative_timing = TestTiming()
test_summary = TestSummary(args.verbose, args.bench) test_summary = TestSummary(args.verbose, args.bench)
results = [] results = []
...@@ -244,7 +243,7 @@ def main(): ...@@ -244,7 +243,7 @@ def main():
# 3. Execution Loop # 3. Execution Loop
for f in test_files: for f in test_files:
result = driver.drive(f) result = executor.execute(f)
results.append(result) results.append(result)
# Real-time reporting and printing of stdout # Real-time reporting and printing of stdout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment