Unverified Commit 8d09630a authored by gongchensu, committed by GitHub

Merge branch 'demo131' into Issue/862

parents ab52dead 012df56c
......@@ -138,4 +138,34 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return freeDevice(ptr);
}
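// Note: intentional stubs. Graph capture is not implemented for the MUSA
// backend, so each entry point below reports
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED and callers can detect the missing
// feature at runtime rather than failing at build time.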
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphInstantiate(
infinirtGraphExec_t *graph_exec_ptr,
infinirtGraph_t graph,
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
infiniStatus_t graphLaunch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
} // namespace infinirt::musa
from .base import TestConfig, TestRunner, BaseOperatorTest
from .entities import TestCase
from .benchmark import BenchmarkUtils, BenchmarkResult
from .config import (
add_common_test_args,
......@@ -9,35 +9,49 @@ from .config import (
)
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .results import TestTiming, OperatorResult, CaseResult, TestSummary
from .runner import GenericTestRunner
from .test_manager import TestManager, TestCollector
from .tensor import TensorSpec, TensorInitializer
from .executor import TestExecutor
from .utils.compare_utils import (
compare_results,
create_test_comparator,
debug,
get_tolerance,
)
from .utils.json_utils import save_json_report
from .utils.load_utils import TestGenerator
from .utils.tensor_utils import (
infinicore_tensor_from_torch,
convert_infinicore_to_torch,
rearrange_tensor,
is_broadcast,
is_complex_dtype,
is_floating_dtype,
is_integer_dtype,
)
__all__ = [
# Core types and classes
"BaseOperatorTest",
"CaseResult",
"GenericTestRunner",
"InfiniDeviceEnum",
"InfiniDeviceNames",
"OperatorResult",
"TestGenerator",
"TensorInitializer",
"TensorSpec",
"TestCase",
"TestCollector",
"TestConfig",
"TestResult",
"TestExecutor",
"TestManager",
"TestSummary",
"TestRunner",
"TestReporter",
"TestTiming",
# Core functions
"add_common_test_args",
"compare_results",
......@@ -50,6 +64,8 @@ __all__ = [
"get_tolerance",
"infinicore_tensor_from_torch",
"rearrange_tensor",
# JSON utilities
"save_json_report",
# Utility functions
"to_infinicore_dtype",
"to_torch_dtype",
......
......@@ -8,15 +8,15 @@ import infinicore
import traceback
from abc import ABC, abstractmethod
from .results import CaseResult
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
from .utils.tensor_utils import (
clone_torch_tensor,
infinicore_tensor_from_torch,
)
from .utils.compare_utils import create_test_comparator
from .benchmark import BenchmarkUtils
......@@ -84,7 +84,7 @@ class TestRunner:
try:
print(f"{test_case}")
# Execute test and get CaseResult object
test_result = test_func(device, test_case, self.config)
self.test_results.append(test_result)
......@@ -118,8 +118,8 @@ class TestRunner:
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
# Create a failed CaseResult
failed_result = CaseResult(
success=False,
return_code=-1,
error_message=str(e),
......@@ -342,7 +342,10 @@ class BaseOperatorTest(ABC):
for i, inp in enumerate(inputs):
if isinstance(inp, torch.Tensor):
# Clone only if this input will be used for comparison
if comparison_target == i or (
isinstance(comparison_target, (list, tuple))
and i in comparison_target
):
cloned_inp = clone_torch_tensor(inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
cloned_tensors.append(cloned_inp)
......@@ -400,12 +403,12 @@ class BaseOperatorTest(ABC):
config: Test configuration
Returns:
CaseResult: Test case result object containing status and timing information
"""
device_str = torch_device_map[device]
# Initialize test case result
test_result = CaseResult(
success=False,
return_code=-1, # Default to failure
test_case=test_case,
......@@ -508,7 +511,9 @@ class BaseOperatorTest(ABC):
# Handle multiple outputs comparison
# Determine what to compare based on comparison_target
if comparison_target is None or isinstance(
comparison_target, (list, tuple)
):
# Compare return values (out-of-place multiple outputs)
torch_comparison = torch_result
infini_comparison = infini_result
......@@ -573,7 +578,9 @@ class BaseOperatorTest(ABC):
# ==========================================================================
else:
# Determine comparison targets for single output
if comparison_target is None or isinstance(
comparison_target, (list, tuple)
):
# Compare return values (out-of-place)
torch_comparison = torch_result
infini_comparison = infini_result
......
......@@ -5,7 +5,7 @@ Benchmarking utilities for the InfiniCore testing framework
import time
import torch
import infinicore
from .utils.tensor_utils import synchronize_device
class BenchmarkUtils:
......
......@@ -24,6 +24,7 @@ def get_supported_hardware_platforms():
("--kunlun", "Kunlun XPUs (requires torch_xmlir)"),
("--hygon", "Hygon DCUs"),
("--qy", "QY GPUs"),
("--ali", "Ali PPU accelerators"),
]
......@@ -230,13 +231,21 @@ def get_test_devices(args):
if args.qy:
try:
# QY GPU detection
import torch
devices_to_test.append(InfiniDeviceEnum.QY)
except ImportError:
print("Warning: QY GPU support not available")
if args.ali:
try:
import torch
devices_to_test.append(InfiniDeviceEnum.ALI)
except ImportError:
print("Warning: Ali PPU support not available")
# Default to CPU if no devices specified
if not devices_to_test:
devices_to_test = [InfiniDeviceEnum.CPU]
......
......@@ -9,6 +9,7 @@ class InfiniDeviceEnum:
KUNLUN = 7
HYGON = 8
QY = 9
ALI = 10
InfiniDeviceNames = {
......@@ -22,6 +23,7 @@ InfiniDeviceNames = {
InfiniDeviceEnum.QY: "Qy",
InfiniDeviceEnum.KUNLUN: "Kunlun",
InfiniDeviceEnum.HYGON: "Hygon",
InfiniDeviceEnum.ALI: "Ali",
}
torch_device_map = {
......@@ -35,4 +37,5 @@ torch_device_map = {
InfiniDeviceEnum.KUNLUN: "cuda",
InfiniDeviceEnum.HYGON: "cuda",
InfiniDeviceEnum.QY: "cuda",
InfiniDeviceEnum.ALI: "cuda",
}
......@@ -7,21 +7,6 @@ from typing import List, Dict, Any, Optional, Tuple
from .tensor import TensorSpec
@dataclass
class TestResult:
"""Test result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
class TestCase:
"""Test case with all configuration included"""
......
import sys
import importlib.util
from io import StringIO
from contextlib import contextmanager
from .results import OperatorResult, TestSummary
@contextmanager
def capture_output():
"""Context manager: captures stdout and stderr."""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield new_out, new_err
finally:
sys.stdout, sys.stderr = old_out, old_err
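# A minimal usage sketch (illustrative, not part of the framework API): prints
# inside the block go to the StringIO buffers instead of the console, and the
# original streams are always restored by the finally clause above.
#
#     with capture_output() as (out, err):
#         print("hello")
#     assert out.getvalue() == "hello\n"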
class TestExecutor:
def execute(self, file_path, test_args) -> OperatorResult:
"""
Execute a test file dynamically.
Args:
file_path (Path): Path to the python test file.
test_args (argparse.Namespace): Arguments to pass to the runner. Must be provided.
"""
result = OperatorResult(name=file_path.stem)
try:
# 1. Dynamically import the module
module = self._import_module(file_path)
# 2. Look for TestRunner
if not hasattr(module, "GenericTestRunner"):
raise ImportError("No GenericTestRunner found in module")
# 3. Look for TestClass (subclass of BaseOperatorTest)
test_class = self._find_test_class(module)
if not test_class:
raise ImportError("No BaseOperatorTest subclass found")
test_instance = test_class()
runner_class = module.GenericTestRunner
runner = runner_class(test_instance.__class__, args=test_args)
# 4. Execute and capture output
with capture_output() as (out, err):
success, internal_runner = runner.run()
# 5. Populate results
result.success = success
result.stdout = out.getvalue()
result.stderr = err.getvalue()
# Extract detailed results from internal_runner
test_results = internal_runner.get_test_results() if internal_runner else []
test_summary = TestSummary()
test_summary.process_operator_result(result, test_results)
# Store saved report file if available
result.saved_file = runner.saved_file
except Exception as e:
result.success = False
result.error_message = str(e)
result.stderr += f"\nExecutor Error: {str(e)}"
result.return_code = -1
return result
def _import_module(self, path):
module_name = f"op_test_{path.stem}"
spec = importlib.util.spec_from_file_location(module_name, path)
if not spec or not spec.loader:
raise ImportError(f"Could not load spec from {path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def _find_test_class(self, module):
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, type) and hasattr(attr, "__bases__"):
# Simple check for base class name
if any("BaseOperatorTest" in str(b) for b in attr.__bases__):
return attr
return None
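# Hypothetical usage sketch (the file path and namespace below are made up; a
# real caller would pass the argparse.Namespace prepared by run.py):
#
#     from pathlib import Path
#     from argparse import Namespace
#     result = TestExecutor().execute(Path("ops/add.py"), test_args=Namespace())
#     print(result.status_icon, result.name, result.return_code)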
import json
import os
from datetime import datetime
from typing import List, Dict, Any, Union
from dataclasses import is_dataclass
from .base import TensorSpec
from .devices import InfiniDeviceEnum
class TestReporter:
"""
Handles report generation and file saving for test results.
"""
@staticmethod
def prepare_report_entry(
op_name: str,
test_cases: List[Any],
args: Any,
op_paths: Dict[str, str],
results_list: List[Any]
) -> List[Dict[str, Any]]:
"""
Combines static test case info with dynamic execution results.
"""
# 1. Normalize results
results_map = {}
if isinstance(results_list, list):
results_map = {i: res for i, res in enumerate(results_list)}
elif isinstance(results_list, dict):
results_map = results_list
else:
results_map = {0: results_list} if results_list else {}
# 2. Global Args
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
grouped_entries: Dict[int, Dict[str, Any]] = {}
# 3. Iterate Test Cases
for idx, tc in enumerate(test_cases):
res = results_map.get(idx)
dev_id = getattr(res, "device", 0) if res else 0
# --- A. Initialize Group ---
if dev_id not in grouped_entries:
device_id_map = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")}
dev_str = device_id_map.get(dev_id, str(dev_id))
grouped_entries[dev_id] = {
"operator": op_name,
"device": dev_str,
"torch_op": op_paths.get("torch") or "unknown",
"infinicore_op": op_paths.get("infinicore") or "unknown",
"args": global_args,
"testcases": []
}
# --- B. Build Kwargs ---
display_kwargs = {}
# B1. Process existing kwargs
for k, v in tc.kwargs.items():
# Handle Inplace: "out": index -> "out": "input_name"
if k == "out" and isinstance(v, int):
if 0 <= v < len(tc.inputs):
display_kwargs[k] = tc.inputs[v].name
else:
display_kwargs[k] = f"Invalid_Index_{v}"
else:
display_kwargs[k] = (TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v)
# B2. Inject Outputs into Kwargs
if hasattr(tc, "output_specs") and tc.output_specs:
for i, spec in enumerate(tc.output_specs):
display_kwargs[f"out_{i}"] = TestReporter._spec_to_dict(spec)
elif tc.output_spec:
if "out" not in display_kwargs:
display_kwargs["out"] = TestReporter._spec_to_dict(tc.output_spec)
# --- C. Build Test Case Dictionary ---
case_data = {
"description": tc.description,
"inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
}
# --- D. Inject Result ---
if res:
case_data["result"] = TestReporter._fmt_result(res)
else:
case_data["result"] = {"status": {"success": False, "error": "No result"}}
grouped_entries[dev_id]["testcases"].append(case_data)
return list(grouped_entries.values())
@staticmethod
def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
"""
Saves the report list to a JSON file with specific custom formatting
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# Define indentation levels for cleaner code
indent_4 = ' ' * 4
indent_8 = ' ' * 8
indent_12 = ' ' * 12
indent_16 = ' ' * 16
indent_20 = ' ' * 20
print(f"💾 Saving to: {final_path}")
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(f"{indent_4}{{\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
# -------------------------------------------------
# Special Handling for 'testcases' list formatting
# -------------------------------------------------
if key == "testcases" and isinstance(val, list):
f.write(f'{indent_8}"{key}": [\n')
for c_idx, case_item in enumerate(val):
f.write(f"{indent_12}{{\n")
case_keys = list(case_item.keys())
for k_idx, c_key in enumerate(case_keys):
c_val = case_item[c_key]
# [Logic A] Skip fields we merged manually after 'kwargs'
if c_key in ["comparison_target", "tolerance"]:
continue
# Check comma for standard logic (might be overridden below)
c_comma = "," if k_idx < len(case_keys) - 1 else ""
# [Logic B] Handle 'kwargs' + Grouped Fields
if c_key == "kwargs":
# 1. Use Helper for kwargs (Fill/Flow logic)
TestReporter._write_smart_field(
f, c_key, c_val, indent_16, indent_20, close_comma=","
)
# 2. Write subsequent comparison_target and tolerance (on a new line)
cmp_v = json.dumps(case_item.get("comparison_target"), ensure_ascii=False)
tol_v = json.dumps(case_item.get("tolerance"), ensure_ascii=False)
remaining_keys = [k for k in case_keys[k_idx+1:] if k not in ("comparison_target", "tolerance")]
line_comma = "," if remaining_keys else ""
f.write(f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n')
continue
# [Logic C] Handle 'inputs' (Smart Wrap)
if c_key == "inputs" and isinstance(c_val, list):
TestReporter._write_smart_field(
f, c_key, c_val, indent_16, indent_20, close_comma=c_comma
)
continue
# [Logic D] Standard fields (description, result, output_spec, etc.)
else:
c_val_str = json.dumps(c_val, ensure_ascii=False)
f.write(f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n')
close_comma = "," if c_idx < len(val) - 1 else ""
f.write(f"{indent_12}}}{close_comma}\n")
f.write(f"{indent_8}]{comma}\n")
# -------------------------------------------------
# Standard top-level fields (operator, args, etc.)
# -------------------------------------------------
else:
k_str = json.dumps(key, ensure_ascii=False)
v_str = json.dumps(val, ensure_ascii=False)
f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")
if i < len(total_results) - 1:
f.write(f"{indent_4}}},\n")
else:
f.write(f"{indent_4}}}\n")
f.write("]\n")
print(f" ✅ Saved (Structure Matched).")
except Exception as e:
import traceback; traceback.print_exc()
print(f" ❌ Save failed: {e}")
# --- Internal Helpers ---
@staticmethod
def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Helper to write a JSON field (List or Dict) with smart wrapping.
- If compact length <= 180: Write on one line.
- If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
if len(compact_json) <= 180:
f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
return
# 2. Fill/Flow Mode
is_dict = isinstance(value, dict)
open_char = '{' if is_dict else '['
close_char = '}' if is_dict else ']'
f.write(f'{indent}"{key}": {open_char}')
# Normalize items for iteration
if is_dict:
items = list(value.items())
else:
items = value # List
# Initialize current line length tracking
# Length includes indent + "key": [
current_len = len(indent) + len(f'"{key}": {open_char}')
for i, item in enumerate(items):
# Format individual item string
if is_dict:
k, v = item
val_str = json.dumps(v, ensure_ascii=False)
item_str = f'"{k}": {val_str}'
else:
item_str = json.dumps(item, ensure_ascii=False)
is_last = (i == len(items) - 1)
item_comma = "" if is_last else ", "
# Predict new length: current + item + comma
if current_len + len(item_str) + len(item_comma) > 180:
# Wrap to new line
f.write(f'\n{sub_indent}')
current_len = len(sub_indent)
f.write(f'{item_str}{item_comma}')
current_len += len(item_str) + len(item_comma)
f.write(f'{close_char}{close_comma}\n')
@staticmethod
def _spec_to_dict(s):
return {
"name": getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
@staticmethod
def _fmt_result(res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
get_time = lambda k: round(getattr(res, k, 0.0), 4)
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": get_time("torch_host_time"),
"device": get_time("torch_device_time"),
},
"infinicore": {
"host": get_time("infini_host_time"),
"device": get_time("infini_device_time"),
},
},
}
from typing import List, Dict, Any
from dataclasses import dataclass, is_dataclass, field
from .devices import InfiniDeviceEnum
from .tensor import TensorSpec
from .utils.json_utils import save_json_report
@dataclass
class CaseResult:
"""Test case result data structure"""
success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None
@dataclass
class TestTiming:
"""Stores performance timing metrics."""
torch_host: float = 0.0
torch_device: float = 0.0
infini_host: float = 0.0
infini_device: float = 0.0
# Field added to support the aggregation logic in print_summary
operators_tested: int = 0
@dataclass
class OperatorResult:
"""Stores the execution results of a single operator."""
name: str
success: bool = False
return_code: int = -1
error_message: str = ""
stdout: str = ""
stderr: str = ""
timing: TestTiming = field(default_factory=TestTiming)
saved_file: str = "" # Path to the saved report file
@property
def status_icon(self):
if self.return_code == 0:
return "✅"
if self.return_code == -2:
return "⏭️"
if self.return_code == -3:
return "⚠️"
return "❌"
@property
def status_text(self):
if self.return_code == 0:
return "PASSED"
if self.return_code == -2:
return "SKIPPED"
if self.return_code == -3:
return "PARTIAL"
return "FAILED"
class TestSummary:
"""
Test Summary class:
1. Aggregates results (Timing & Status calculation).
2. Handles Console Output (Live & Summary).
3. Handles File Reporting (Data Preparation).
"""
def __init__(self, verbose=False, bench_mode=None):
self.verbose = verbose
self.bench_mode = bench_mode
self.report_entries = [] # Cache for JSON report
# =========================================================
# Part 1: Result Aggregation
# =========================================================
def process_operator_result(self, op_result, sub_results: List):
"""
Updates the OperatorResult object in-place.
"""
if not sub_results:
op_result.return_code = -1
return
# 1. Analyze Return Code (Status)
if op_result.success:
op_result.return_code = 0
else:
has_failures = any(r.return_code == -1 for r in sub_results)
has_partial = any(r.return_code == -3 for r in sub_results)
has_skipped = any(r.return_code == -2 for r in sub_results)
if has_failures:
op_result.return_code = -1
elif has_partial:
op_result.return_code = -3
elif has_skipped:
op_result.return_code = -2
else:
op_result.return_code = -1
# 2. Extract Timing (Aggregation)
t = op_result.timing
t.torch_host = sum(r.torch_host_time for r in sub_results)
t.torch_device = sum(r.torch_device_time for r in sub_results)
t.infini_host = sum(r.infini_host_time for r in sub_results)
t.infini_device = sum(r.infini_device_time for r in sub_results)
t.operators_tested = len(sub_results)
# =========================================================
# Part 2: Console Output (View)
# =========================================================
def list_tests(self, collector):
ops_dir = collector.ops_dir
operators = collector.get_available_operators()
if operators:
print(f"Available operator test files in {ops_dir}:")
for operator in operators:
print(f" - {operator}")
print(f"\nTotal: {len(operators)} operators")
else:
print(f"No valid operator tests found in {ops_dir}")
raw_files = collector.get_raw_python_files()
if raw_files:
print(
f"\n💡 Debug Hint: Found Python files but they are not valid tests:"
)
print(f" {raw_files}")
def print_header(self, ops_dir, count):
print(f"InfiniCore Operator Test Runner")
print(f"Directory: {ops_dir}")
print(f"Tests found: {count}\n")
def print_live_result(self, result):
print(
f"{result.status_icon} {result.name}: {result.status_text} (code: {result.return_code})"
)
if result.stdout:
print(result.stdout.rstrip())
if result.stderr:
print("\nSTDERR:", result.stderr.rstrip())
if result.error_message:
print(f"💥 Error: {result.error_message}")
if result.stdout or result.stderr or self.verbose:
print("-" * 40)
def print_summary(self, results, cumulative_timing, ops_dir, total_expected=0):
print(f"\n{'='*80}\nCUMULATIVE TEST SUMMARY\n{'='*80}")
passed = [r for r in results if r.return_code == 0]
failed = [r for r in results if r.return_code == -1]
skipped = [r for r in results if r.return_code == -2]
partial = [r for r in results if r.return_code == -3]
total = len(results)
print(f"Total tests run: {total}")
if total_expected > 0 and total < total_expected:
print(f"Total tests expected: {total_expected}")
print(f"Tests not executed: {total_expected - total}")
print(f"Passed: {len(passed)}")
print(f"Failed: {len(failed)}")
if skipped:
print(f"Skipped: {len(skipped)}")
if partial:
print(f"Partial: {len(partial)}")
# 1. Benchmark
if cumulative_timing:
self._print_timing(cumulative_timing)
# 2. Lists
if passed:
self._print_op_list("✅ PASSED OPERATORS", passed)
else:
print(f"\n✅ PASSED OPERATORS: None")
if failed:
self._print_op_list("❌ FAILED OPERATORS", failed)
if skipped:
self._print_op_list("⏭️ SKIPPED OPERATORS", skipped)
if partial:
self._print_op_list("⚠️ PARTIAL IMPLEMENTATIONS", partial)
# 3. Verdict
if total > 0:
executed_tests = total - len(skipped)
if executed_tests > 0:
success_rate = len(passed) / executed_tests * 100
print(f"\nSuccess rate: {success_rate:.1f}%")
if not failed:
if skipped or partial:
print(f"\n⚠️ Tests completed with some operators not fully implemented")
else:
print(f"\n🎉 All tests passed!")
else:
print(f"\n{len(failed)} tests failed")
if not failed and (skipped or partial):
print(f"\n⚠️ Note: Some operators are not fully implemented")
print(f" Run individual tests for details on missing implementations")
if self.verbose and failed:
print(
f"\n💡 Verbose mode tip: Use individual test commands for detailed debugging:"
)
for r in failed[:3]:
file_path = ops_dir / (r.name + ".py")
print(f" python {file_path} --verbose")
if len(failed) > 3:
print(f" ... (and {len(failed) - 3} others)")
return len(failed) == 0
def _print_timing(self, t):
print(f"{'-'*40}")
if hasattr(t, "operators_tested") and t.operators_tested > 0:
print(f"BENCHMARK SUMMARY ({t.operators_tested} cases):")
if self.bench_mode in ["host", "both"]:
print(f" [Host] PyTorch: {t.torch_host:10.3f} ms")
print(f" [Host] InfiniCore: {t.infini_host:10.3f} ms")
if self.bench_mode in ["device", "both"]:
print(f" [Device] PyTorch: {t.torch_device:10.3f} ms")
print(f" [Device] InfiniCore: {t.infini_device:10.3f} ms")
print(f"{'-'*40}")
def _print_op_list(self, title, result_list):
print(f"\n{title} ({len(result_list)}):")
names = [r.name for r in result_list]
for i in range(0, len(names), 10):
print(" " + ", ".join(names[i : i + 10]))
# =========================================================
# Part 3: Report Generation
# =========================================================
def collect_report_entry(self, op_name, test_cases, args, op_paths, results_list):
"""
Prepares the data and adds it to the internal list.
"""
entry = self._prepare_entry_logic(
op_name, test_cases, args, op_paths, results_list
)
self.report_entries.extend(entry)
def save_report(self, save_path):
"""
Delegates the actual writing to save_json_report.
Returns the actual file path that was saved (with timestamp).
"""
if not self.report_entries:
return None
# Call the external utility and get the actual saved path
return save_json_report(save_path, self.report_entries)
def _prepare_entry_logic(self, op_name, test_cases, args, op_paths, results_list):
"""
Combines static test case info with dynamic execution results.
Refactored to reduce duplication.
"""
# 1. Normalize results
results_map = (
results_list
if isinstance(results_list, dict)
else {i: res for i, res in enumerate(results_list or [])}
)
# 2. Global Args
global_args = {
k: getattr(args, k)
for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
if hasattr(args, k)
}
grouped_entries = {}
# Cache device enum map
device_id_map = {
v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")
}
for idx, tc in enumerate(test_cases):
res = results_map.get(idx)
dev_id = getattr(res, "device", 0) if res else 0
# --- A. Initialize Group ---
if dev_id not in grouped_entries:
grouped_entries[dev_id] = {
"operator": op_name,
"device": device_id_map.get(dev_id, str(dev_id)),
"torch_op": op_paths.get("torch", "unknown"),
"infinicore_op": op_paths.get("infinicore", "unknown"),
"args": global_args,
"testcases": [],
}
# --- B. Helpers for Spec Processing ---
def process_spec(spec, default_name):
final_name = self._resolve_name(spec, default_name)
# Call internal method (no need for external converters file)
return self._spec_to_dict(spec, name=final_name)
# --- C. Build Inputs ---
processed_inputs = [
process_spec(inp, f"in_{i}") for i, inp in enumerate(tc.inputs)
]
# --- D. Build Kwargs ---
display_kwargs = {}
for k, v in tc.kwargs.items():
if k == "out" and isinstance(v, int):
# Handle Inplace Index
if 0 <= v < len(tc.inputs):
display_kwargs[k] = self._resolve_name(tc.inputs[v], f"in_{v}")
else:
display_kwargs[k] = f"Invalid_Index_{v}"
elif isinstance(v, TensorSpec):
display_kwargs[k] = process_spec(v, v.name)
else:
display_kwargs[k] = v
# --- E. Inject Outputs ---
if getattr(tc, "output_specs", None):
for i, spec in enumerate(tc.output_specs):
display_kwargs[f"out_{i}"] = process_spec(spec, f"out_{i}")
elif tc.output_spec and "out" not in display_kwargs:
display_kwargs["out"] = process_spec(tc.output_spec, "out")
# --- F. Assemble Case Data ---
case_data = {
"description": tc.description,
"inputs": processed_inputs,
"kwargs": display_kwargs,
"comparison_target": tc.comparison_target,
"tolerance": tc.tolerance,
"result": (
self._fmt_result(res)
if res
else {"status": {"success": False, "error": "No result"}}
),
}
grouped_entries[dev_id]["testcases"].append(case_data)
return list(grouped_entries.values())
# --- Internal Helpers ---
def _resolve_name(self, obj, default_name):
return getattr(obj, "name", None) or default_name
def _spec_to_dict(self, s, name=None):
return {
"name": name if name else getattr(s, "name", "unknown"),
"shape": list(s.shape) if s.shape else None,
"dtype": str(s.dtype).split(".")[-1],
"strides": list(s.strides) if s.strides else None,
}
def _fmt_result(self, res):
if not (is_dataclass(res) or hasattr(res, "success")):
return str(res)
get_time = lambda k: round(getattr(res, k, 0.0), 4)
return {
"status": {
"success": getattr(res, "success", False),
"error": getattr(res, "error_message", ""),
},
"perf_ms": {
"torch": {
"host": get_time("torch_host_time"),
"device": get_time("torch_device_time"),
},
"infinicore": {
"host": get_time("infini_host_time"),
"device": get_time("infini_device_time"),
},
},
}
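# Shape of the dict returned by _fmt_result (values illustrative):
#
#     {"status": {"success": True, "error": ""},
#      "perf_ms": {"torch": {"host": 0.12, "device": 0.05},
#                  "infinicore": {"host": 0.11, "device": 0.04}}}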
......@@ -7,7 +7,7 @@ import os
import inspect
import re
from . import TestConfig, TestRunner, get_args, get_test_devices
from .results import TestSummary
class GenericTestRunner:
......@@ -20,6 +20,7 @@ class GenericTestRunner:
"""
self.operator_test = operator_test_class()
self.args = args or get_args()
self.saved_file = None # Store the path of saved report
def run(self):
"""Execute the complete test suite
......@@ -56,7 +57,7 @@ class GenericTestRunner:
summary_passed = runner.print_summary()
if getattr(self.args, "save", None):
self.saved_file = self._save_report(runner)
# Both conditions must be True for overall success
# - has_no_failures: no test failures during execution
......@@ -89,7 +90,8 @@ class GenericTestRunner:
op_paths = {"torch": t_path, "infinicore": i_path}
# 2. Generate Report Entries
test_summary = TestSummary()
entries = test_summary.collect_report_entry(
op_name=self.operator_test.operator_name,
test_cases=self.operator_test.test_cases,
args=self.args,
......@@ -97,14 +99,15 @@ class GenericTestRunner:
results_list=runner.test_results,
)
# 3. Save to File and return the file name
return test_summary.save_report(self.args.save)
except Exception as e:
import traceback
traceback.print_exc()
print(f"⚠️ Failed to save report: {e}")
return None
def _infer_op_path(self, method, lib_prefix):
"""
......
......@@ -3,7 +3,7 @@ import math
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
from .utils.tensor_utils import is_integer_dtype, is_complex_dtype
class TensorInitializer:
......@@ -60,7 +60,12 @@ class TensorInitializer:
# Handle real floating-point types
if mode == TensorInitializer.RANDOM:
scale = kwargs.get("scale", 1.0)
bias = kwargs.get("bias", 0.0)
return (
torch.rand(shape, dtype=torch_dtype, device=torch_device_str) * scale
+ bias
)
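# torch.rand samples uniformly from [0, 1), so the branch above yields values
# uniform on [bias, bias + scale).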
elif mode == TensorInitializer.ZEROS:
return torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ONES:
......
import sys
import argparse
import tempfile
from pathlib import Path
from .executor import TestExecutor
from .results import TestSummary, TestTiming
from .utils.load_utils import TestGenerator
class TestCollector:
"""
Responsible for scanning and verifying operator test files.
"""
def __init__(self, ops_dir_path=None):
self.ops_dir = self._resolve_dir(ops_dir_path)
def _resolve_dir(self, path):
if path:
p = Path(path)
if p.exists():
return p
# Fallback: 'ops' directory relative to the project root
fallback = Path(__file__).parent.parent / "ops"
return fallback if fallback.exists() else None
def get_available_operators(self):
if not self.ops_dir:
return []
files = self.scan()
return sorted([f.stem for f in files])
def get_raw_python_files(self):
if not self.ops_dir or not self.ops_dir.exists():
return []
files = list(self.ops_dir.glob("*.py"))
return [
f.name for f in files if f.name != "run.py" and not f.name.startswith("__")
]
def scan(self, specific_ops=None):
if not self.ops_dir or not self.ops_dir.exists():
return []
files = list(self.ops_dir.glob("*.py"))
target_ops_set = set(specific_ops) if specific_ops else None
valid_files = []
for f in files:
if f.name.startswith("_") or f.name == "run.py":
continue
if target_ops_set and f.stem not in target_ops_set:
continue
if self._is_operator_test(f):
valid_files.append(f)
return valid_files
def _is_operator_test(self, file_path):
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return "infinicore" in content and (
"BaseOperatorTest" in content or "GenericTestRunner" in content
)
except Exception:
return False
class TestManager:
"""
High-level API to execute operator tests.
Encapsulates the test loop, timing aggregation, and reporting.
"""
def __init__(self, ops_dir=None, verbose=False, bench_mode=None):
self.collector = TestCollector(ops_dir)
self.verbose = verbose
self.bench_mode = bench_mode
# Initialize components
self.executor = TestExecutor()
self.summary = TestSummary(verbose, bench_mode)
self.cumulative_timing = TestTiming()
self.results = []
def test(self, target_ops=None, json_cases_list=None, global_exec_args=None):
"""
Args:
target_ops: List of target operators for local scan
json_cases_list: List of test cases in JSON mode
global_exec_args (argparse.Namespace): Unified argument object passed to Executor in local scan mode
"""
with tempfile.TemporaryDirectory() as temp_dir_str:
test_files = []
test_configs = [] # Stores args for each file
display_location = ""
# =================================================
# 1. Mode Selection
# =================================================
if json_cases_list:
# [Mode A] Dynamic Execution (JSON)
print(f"🚀 Mode: Dynamic Execution")
project_root = getattr(
self, "project_root", Path(__file__).resolve().parent.parent
)
generator = TestGenerator(project_root=str(project_root))
# Generate files
dynamic_paths = generator.generate(json_cases_list, temp_dir_str)
test_files = [Path(p) for p in dynamic_paths]
# Convert JSON dict to Namespace
for case_data in json_cases_list:
# run.py has sanitized the data, convert directly to Namespace
ns = argparse.Namespace(**case_data.get("args", {}))
test_configs.append(ns)
display_location = f"Dynamic ({len(test_files)} cases)"
else:
# [Mode B] Local File Scan
# print(f"📂 Mode: Local File Scan")
test_files = self.collector.scan(target_ops)
display_location = str(self.collector.ops_dir)
# ✅ Key Logic: Apply global_exec_args passed from run.py to all files
# If global_exec_args is None (run.py should theoretically fill this), executor falls back to default behavior
test_configs = [global_exec_args] * len(test_files)
# =================================================
# 2. Execution Loop
# =================================================
if not test_files:
print(f"No valid tests found in {display_location}")
# Keep the (all_passed, saved_files) return contract even when nothing runs
return True, []
self.summary.print_header(display_location, len(test_files))
saved_files = []
for f, run_args in zip(test_files, test_configs):
# Inject prepared args (whether from JSON or Local global) into Executor
result = self.executor.execute(f, test_args=run_args)
self.results.append(result)
self.summary.print_live_result(result)
# Collect saved report files
if hasattr(result, "saved_file") and result.saved_file:
saved_files.append(result.saved_file)
if result.success:
self._accumulate_timing(result.timing)
if self.verbose and not result.success:
print("\nStopping due to failure in verbose mode.")
break
# Summary
all_passed = self.summary.print_summary(
self.results,
self.cumulative_timing if self.bench_mode else None,
ops_dir=display_location,
total_expected=len(test_files),
)
return all_passed, saved_files
def _accumulate_timing(self, timing):
self.cumulative_timing.torch_host += timing.torch_host
self.cumulative_timing.infini_host += timing.infini_host
self.cumulative_timing.torch_device += timing.torch_device
self.cumulative_timing.infini_device += timing.infini_device
self.cumulative_timing.operators_tested += 1
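# Hypothetical usage sketch (operator names are made up; global_exec_args
# would normally be the argparse.Namespace prepared by run.py):
#
#     manager = TestManager(verbose=True, bench_mode="host")
#     all_passed, saved_files = manager.test(
#         target_ops=["add", "mul"], global_exec_args=None
#     )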
import torch
import time
import infinicore
import numpy as np
from .datatypes import to_infinicore_dtype, to_torch_dtype
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debug function to compare two tensors and print differences
"""
# Handle complex types by converting to real representation for comparison
if actual.is_complex() or desired.is_complex():
actual = torch.view_as_real(actual)
desired = torch.view_as_real(desired)
elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
import numpy as np
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.")
import torch
import sys
is_terminal = sys.stdout.isatty()
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate difference mask
nan_mismatch = (
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
# Display formatting
col_width = [18, 20, 20, 20]
decimal_places = [0, 12, 12, 12]
total_width = sum(col_width) + sum(decimal_places)
def add_color(text, color_code):
if is_terminal:
return f"\033[{color_code}m{text}\033[0m"
else:
return text
if verbose:
for idx in diff_indices:
index_tuple = tuple(idx.tolist())
actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
expected_str = (
f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
)
delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
print(
f" > Index: {str(index_tuple):<{col_width[0]}}"
f"actual: {add_color(actual_str, 31)}"
f"expect: {add_color(expected_str, 32)}"
f"delta: {add_color(delta_str, 33)}"
)
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Equal NaN: {equal_nan}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
print(
f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
)
print(
f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
)
print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
)
print("-" * total_width)
return diff_indices
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
"""
Get tolerance settings based on data type
"""
tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
return tolerance["atol"], tolerance["rtol"]
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
def convert_infinicore_to_torch(infini_result):
"""
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
Returns:
torch.Tensor: PyTorch tensor with infinicore data
"""
torch_result_from_infini = torch.zeros(
infini_result.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
import sys
from ..datatypes import to_torch_dtype
from .tensor_utils import (
convert_infinicore_to_torch,
is_integer_dtype,
is_complex_dtype,
)
def compare_results(
......@@ -351,89 +189,104 @@ def create_test_comparator(config, atol, rtol, mode_name="", equal_nan=False):
return compare_test_results
# (The removed rearrange_tensor, is_broadcast, and dtype-check helpers now live
# in utils/tensor_utils.py; see that file below.)
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
"""
Get tolerance settings based on data type
"""
tolerance = tolerance_map.get(
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
)
return tolerance["atol"], tolerance["rtol"]
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debug function to compare two tensors and print differences
"""
# Handle complex types by converting to real representation for comparison
if actual.is_complex() or desired.is_complex():
actual = torch.view_as_real(actual)
desired = torch.view_as_real(desired)
elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
actual = actual.to(torch.float32)
desired = desired.to(torch.float32)
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
import numpy as np
np.testing.assert_allclose(
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
)
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
"""Print detailed tensor differences"""
if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.")
import torch
import sys
is_terminal = sys.stdout.isatty()
actual_isnan = torch.isnan(actual)
expected_isnan = torch.isnan(expected)
# Calculate difference mask
nan_mismatch = (
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
)
diff_mask = nan_mismatch | (
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
)
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
# Display formatting
col_width = [18, 20, 20, 20]
decimal_places = [0, 12, 12, 12]
total_width = sum(col_width) + sum(decimal_places)
def add_color(text, color_code):
if is_terminal:
return f"\033[{color_code}m{text}\033[0m"
else:
return text
if verbose:
for idx in diff_indices:
index_tuple = tuple(idx.tolist())
actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
expected_str = (
f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
)
delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
print(
f" > Index: {str(index_tuple):<{col_width[0]}}"
f"actual: {add_color(actual_str, 31)}"
f"expect: {add_color(expected_str, 32)}"
f"delta: {add_color(delta_str, 33)}"
)
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Equal NaN: {equal_nan}")
print(
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
)
print(
f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
)
print(
f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
)
print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
)
print("-" * total_width)
return diff_indices
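# Worked example of the mismatch criterion above (illustrative): with atol=0
# and rtol=1e-3, an expected value of 100.0 tolerates |actual - expected| up to
# 0.1, so actual=100.05 passes while actual=100.2 lands in diff_indices.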
import json
import os
from datetime import datetime
def save_json_report(save_path, total_results):
"""
Saves the report list to a JSON file with specific custom formatting
(Compact for short lines, Expanded for long lines).
Returns:
str: The actual file path that was saved (with timestamp), or None if failed.
"""
directory, filename = os.path.split(save_path)
name, ext = os.path.splitext(filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
# Define Indentation
I4, I8, I12, I16, I20 = " " * 4, " " * 8, " " * 12, " " * 16, " " * 20
print(f"💾 Saving to: {final_path}")
# Helper for JSON stringify to avoid repetition
def _to_json(obj):
return json.dumps(obj, ensure_ascii=False)
try:
with open(final_path, "w", encoding="utf-8") as f:
f.write("[\n")
for i, entry in enumerate(total_results):
f.write(f"{I4}{{\n")
keys = list(entry.keys())
for j, key in enumerate(keys):
val = entry[key]
comma = "," if j < len(keys) - 1 else ""
# Special handling for 'testcases' list
if key == "testcases" and isinstance(val, list):
f.write(f'{I8}"{key}": [\n')
for c_idx, case_item in enumerate(val):
f.write(f"{I12}{{\n")
case_keys = list(case_item.keys())
# Filter out keys that we handle specially at the end
standard_keys = [
k
for k in case_keys
if k not in ["comparison_target", "tolerance"]
]
for k_idx, c_key in enumerate(standard_keys):
c_val = case_item[c_key]
# Determine comma logic
c_comma = (
","
if k_idx < len(standard_keys) - 1
or "comparison_target" in case_item
else ""
)
if c_key in ["kwargs", "inputs"]:
_write_field(
f, c_key, c_val, I16, I20, close_comma=c_comma
)
else:
f.write(
f'{I16}"{c_key}": {_to_json(c_val)}{c_comma}\n'
)
# Handle trailing comparison/tolerance fields uniformly
if "comparison_target" in case_item:
cmp = _to_json(case_item.get("comparison_target"))
tol = _to_json(case_item.get("tolerance"))
f.write(
f'{I16}"comparison_target": {cmp}, "tolerance": {tol}\n'
)
close_case = "," if c_idx < len(val) - 1 else ""
f.write(f"{I12}}}{close_case}\n")
f.write(f"{I8}]{comma}\n")
else:
# Standard top-level fields
f.write(f"{I8}{_to_json(key)}: {_to_json(val)}{comma}\n")
close_entry = "}," if i < len(total_results) - 1 else "}"
f.write(f"{I4}{close_entry}\n")
f.write("]\n")
print(f" ✅ Saved.")
return final_path
except Exception as e:
import traceback
traceback.print_exc()
print(f" ❌ Save failed: {e}")
return None
def _write_field(f, key, value, indent, sub_indent, close_comma=""):
"""
Internal Helper: Write a JSON field with wrapping.
"""
# 1. Try Compact Mode
compact_json = json.dumps(value, ensure_ascii=False)
if len(compact_json) <= 180:
f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
return
# 2. Fill/Flow Mode
is_dict = isinstance(value, dict)
open_char = "{" if is_dict else "["
close_char = "}" if is_dict else "]"
f.write(f'{indent}"{key}": {open_char}')
if is_dict:
items = list(value.items())
else:
items = value
current_len = len(indent) + len(f'"{key}": {open_char}')
for i, item in enumerate(items):
if is_dict:
k, v = item
val_str = json.dumps(v, ensure_ascii=False)
item_str = f'"{k}": {val_str}'
else:
item_str = json.dumps(item, ensure_ascii=False)
is_last = i == len(items) - 1
item_comma = "" if is_last else ", "
if current_len + len(item_str) + len(item_comma) > 180:
f.write(f"\n{sub_indent}")
current_len = len(sub_indent)
f.write(f"{item_str}{item_comma}")
current_len += len(item_str) + len(item_comma)
f.write(f"{close_char}{close_comma}\n")
import json
import os
import sys
import pprint
from pathlib import Path
# ==============================================================================
# OpTest Templates
# ==============================================================================
_TEST_FILE_TEMPLATE = r'''import sys
import os
import json
import pprint
# Path Injection
sys.path.insert(0, r"{project_root}")
import torch
import torch.nn.functional
import infinicore
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
)
# ==============================================================================
# Injected Configuration
# ==============================================================================
_OP_CONFIG = {op_config_json}
# ==============================================================================
# Helpers
# ==============================================================================
def _parse_dtype(dtype_str):
"""Convert string dtype to framework/torch object."""
if hasattr(infinicore, dtype_str): return getattr(infinicore, dtype_str)
if hasattr(torch, dtype_str): return getattr(torch, dtype_str)
return dtype_str
def _dict_to_spec(spec_dict):
"""Convert JSON dict to TensorSpec object."""
if not isinstance(spec_dict, dict): return spec_dict
return TensorSpec(
shape=tuple(spec_dict['shape']),
dtype=_parse_dtype(spec_dict['dtype']),
name=spec_dict.get('name'),
strides=tuple(spec_dict['strides']) if spec_dict.get('strides') else None
)
def parse_test_cases():
"""Parse JSON testcases into framework TestCase objects."""
test_cases = []
raw_cases = _OP_CONFIG.get("testcases", [])
for case in raw_cases:
# 1. Parse Inputs and build name-to-index map
input_specs = []
name_to_index = {}
for idx, inp in enumerate(case.get('inputs', [])):
spec = _dict_to_spec(inp)
input_specs.append(spec)
if spec.name:
name_to_index[spec.name] = idx
# 2. Parse Kwargs
kwargs = {}
for k, v in case.get('kwargs', {}).items():
# Resolve string references (e.g., "out": "a" -> "out": 0)
if k == "out" and isinstance(v, str) and v in name_to_index:
kwargs[k] = name_to_index[v]
elif isinstance(v, dict) and "shape" in v:
kwargs[k] = _dict_to_spec(v)
else:
kwargs[k] = v
# 3. Handle explicit output spec
output_spec = None
if "out" in kwargs and isinstance(kwargs["out"], TensorSpec):
output_spec = kwargs.pop("out")
# 4. Tolerance & Comparison Target
tol_dict = case.get('tolerance', {})
tolerance = {"atol": tol_dict.get("atol", 0), "rtol": tol_dict.get("rtol", 1e-3)}
comp_target = case.get('comparison_target')
if isinstance(comp_target, str) and comp_target in name_to_index:
comp_target = name_to_index[comp_target]
test_cases.append(TestCase(
inputs=input_specs,
kwargs=kwargs,
output_spec=output_spec,
comparison_target=comp_target,
tolerance=tolerance,
description=case.get('description', "Dynamic Case")
))
return test_cases
class OpTest(BaseOperatorTest):
def __init__(self):
super().__init__(_OP_CONFIG.get("operator", "UnknownOp"))
def get_test_cases(self):
"""Returns the list of parsed test cases."""
return parse_test_cases()
def _resolve_kwargs(self, args, kwargs):
"""Resolves index-based 'out' arguments to actual Tensors."""
resolved_kwargs = kwargs.copy()
if "out" in resolved_kwargs:
val = resolved_kwargs["out"]
if isinstance(val, int) and 0 <= val < len(args):
resolved_kwargs["out"] = args[val]
return resolved_kwargs
def torch_operator(self, *args, **kwargs):
"""PyTorch operator implementation."""
{torch_method_body}
def infinicore_operator(self, *args, **kwargs):
"""InfiniCore operator implementation."""
{infini_method_body}
def main():
"""Execution entry point."""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
'''
class TestGenerator:
def __init__(self, project_root):
self.project_root = os.path.abspath(project_root)
def generate(self, json_list, output_dir):
generated_files = []
for idx, op_config in enumerate(json_list):
op_name = op_config.get("operator", "Unknown")
file_name = f"test_{op_name}_{idx}.py"
file_path = os.path.join(output_dir, file_name)
# 1. Fetch operator names
torch_op_name = op_config.get("torch_op")
infinicore_op_name = op_config.get("infinicore_op")
# 2. Prepare method bodies
# If the op name is provided, generate the return statement.
# If it's None/null, use 'pass' to avoid syntax errors.
make_body = lambda name, tag: (
f"return {name}(*args, **self._resolve_kwargs(args, kwargs))"
if name else f"pass # {tag} is null, skipping implementation"
)
torch_body = make_body(torch_op_name, "torch_op")
infini_body = make_body(infinicore_op_name, "infinicore_op")
# 3. Fill the template
config_str = pprint.pformat(op_config, indent=4, width=120)
file_content = _TEST_FILE_TEMPLATE.replace("{op_config_json}", config_str)
file_content = file_content.replace("{project_root}", self.project_root)
# Injected Method Bodies
file_content = file_content.replace("{torch_method_body}", torch_body)
file_content = file_content.replace("{infini_method_body}", infini_body)
with open(file_path, "w", encoding="utf-8") as f:
f.write(file_content)
generated_files.append(file_path)
return generated_files
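# A hypothetical input entry for generate() (field names follow the template's
# parse_test_cases above; the operator paths are made up):
#
#     {
#         "operator": "add",
#         "torch_op": "torch.add",
#         "infinicore_op": "infinicore.add",
#         "testcases": [{
#             "inputs": [
#                 {"name": "a", "shape": [2, 3], "dtype": "float32", "strides": None},
#                 {"name": "b", "shape": [2, 3], "dtype": "float32", "strides": None},
#             ],
#             "kwargs": {},
#             "comparison_target": None,
#             "tolerance": {"atol": 0, "rtol": 1e-3},
#             "description": "example case",
#         }],
#     }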
import torch
import infinicore
from ..datatypes import to_infinicore_dtype, to_torch_dtype
# =================================================================
# Device & Synchronization
# =================================================================
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
# =================================================================
# Tensor Operations & Conversions
# =================================================================
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
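# Note: both branches wrap the tensor's existing storage via data_ptr() rather
# than copying, so the infinicore tensor aliases the torch tensor's memory and
# the torch tensor must outlive the returned view.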
def convert_infinicore_to_torch(infini_result):
"""
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
Returns:
torch.Tensor: PyTorch tensor with infinicore data
"""
torch_result_from_infini = torch.zeros(
infini_result.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
def rearrange_tensor(tensor, new_strides):
"""
Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.
"""
shape = tensor.shape
new_size = [0] * len(shape)
left = 0
right = 0
for i in range(len(shape)):
if new_strides[i] >= 0:
new_size[i] = (shape[i] - 1) * new_strides[i] + 1
right += new_strides[i] * (shape[i] - 1)
else: # TODO: Support negative strides in the future
# new_size[i] = (shape[i] - 1) * (-new_strides[i]) + 1
# left += new_strides[i] * (shape[i] - 1)
raise ValueError("Negative strides are not supported yet")
# Create a new tensor with zeros
new_tensor = torch.zeros(
(right - left + 1,), dtype=tensor.dtype, device=tensor.device
)
# Generate indices for original tensor based on original strides
indices = [torch.arange(s) for s in shape]
mesh = torch.meshgrid(*indices, indexing="ij")
# Flatten indices for linear indexing
linear_indices = [m.flatten() for m in mesh]
# Calculate new positions based on new strides
new_positions = sum(
linear_indices[i] * new_strides[i] for i in range(len(shape))
).to(tensor.device)
offset = -left
new_positions += offset
# Copy the original data to the new tensor
new_tensor.reshape(-1).index_add_(0, new_positions, tensor.reshape(-1))
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor
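# Illustrative example: materialize a (2, 3) tensor with column-major strides
# (1, 2). The data is copied into a fresh buffer laid out per the requested
# strides, so the result compares equal element-wise to the input.
#
#   t = torch.arange(6).reshape(2, 3)
#   r = rearrange_tensor(t, [1, 2])
#   assert r.stride() == (1, 2) and torch.equal(r, t)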
def is_broadcast(strides):
"""
Check if strides indicate a broadcasted tensor
Args:
strides: Tensor strides or None
Returns:
bool: True if the tensor is broadcasted (has zero strides)
"""
if strides is None:
return False
return any(s == 0 for s in strides)
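# Example: expanding a size-1 dimension yields a zero stride, which this
# helper detects.
#
#   t = torch.zeros(1, 4).expand(8, 4)  # stride() == (0, 1)
#   assert is_broadcast(t.stride())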
# =================================================================
# Type Checks (Moved here to avoid circular imports in check.py)
# =================================================================
def is_integer_dtype(dtype):
"""Check if dtype is integer type"""
return dtype in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.bool,
]
def is_complex_dtype(dtype):
"""Check if dtype is complex type"""
return dtype in [torch.complex64, torch.complex128]
def is_floating_dtype(dtype):
"""Check if dtype is floating-point type"""
return dtype in [
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
]
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework import BaseOperatorTest, GenericTestRunner, TensorSpec, TestCase
from framework.tensor import TensorInitializer
import infinicore
# Test cases format: (nlayers, batch_size, hidden_size, nhead, nkvhead, dim, seqlen, past_seqlen, max_seqlen)
_TEST_CASES_DATA = [
(28, 1, 3584, 28, 28, 128, 1, 256, 512),
]
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 1e-4, "rtol": 5e-2},
}
_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16]
def parse_test_cases():
cases = []
for (
nlayers,
batch_size,
hidden_size,
nhead,
nkvhead,
dim,
seqlen,
past_seqlen,
max_seqlen,
) in _TEST_CASES_DATA:
for dtype in _TENSOR_DTYPES:
tol = _TOLERANCE_MAP[dtype]
hidden_states = TensorSpec.from_tensor(
(batch_size, seqlen, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2
)
pos_ids = TensorSpec.from_tensor(
(batch_size, seqlen),
dtype=infinicore.int64,
init_mode=TensorInitializer.RANDINT,
low=0,
high=max_seqlen,
)
k_cache = TensorSpec.from_tensor(
(nlayers, batch_size, nkvhead, max_seqlen, dim),
dtype=dtype,
scale=1e-1,
bias=-5e-2,
)
v_cache = TensorSpec.from_tensor(
(nlayers, batch_size, nkvhead, max_seqlen, dim),
dtype=dtype,
scale=1e-1,
bias=-5e-2,
)
q_proj_w = TensorSpec.from_tensor(
(nhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2
)
k_proj_w = TensorSpec.from_tensor(
(nkvhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2
)
v_proj_w = TensorSpec.from_tensor(
(nkvhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2
)
o_proj_w = TensorSpec.from_tensor(
(hidden_size, nhead * dim), dtype=dtype, scale=1e-1, bias=-5e-2
)
norm_w = TensorSpec.from_tensor(
(hidden_size,), dtype=dtype, scale=1e-1, bias=-5e-2
)
sin_table = TensorSpec.from_tensor(
(max_seqlen, dim // 2), dtype=dtype, scale=1e-1, bias=-5e-2
)
cos_table = TensorSpec.from_tensor(
(max_seqlen, dim // 2), dtype=dtype, scale=1e-1, bias=-5e-2
)
# Out-of-place
cases.append(
TestCase(
inputs=[
hidden_states,
pos_ids,
nhead,
nkvhead,
dim,
past_seqlen,
nlayers,
k_cache,
v_cache,
q_proj_w,
k_proj_w,
v_proj_w,
o_proj_w,
norm_w,
sin_table,
cos_table,
],
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tol,
description="Graph",
)
)
return cases
def torch_rope(
q: torch.Tensor,
k: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
pos_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
q, k: [B, H, S, D]
sin, cos: [max_S, D//2]
pos_ids: [B, S]
"""
def rotate_half(x: torch.Tensor) -> torch.Tensor:
# x: [..., head_dim]
x_even = x[..., 0::2]
x_odd = x[..., 1::2]
return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
B, H, S, D = q.shape
assert D % 2 == 0
# Gather sin/cos by position
# -> [B, S, D//2]
sin = sin[pos_ids]
cos = cos[pos_ids]
# Expand to broadcast over heads
# -> [B, 1, S, D//2]
sin = sin.unsqueeze(1)
cos = cos.unsqueeze(1)
# Interleave to full dim
sin = torch.repeat_interleave(sin, 2, dim=-1)
cos = torch.repeat_interleave(cos, 2, dim=-1)
# Apply RoPE
q_rot = (q * cos) + (rotate_half(q) * sin)
k_rot = (k * cos) + (rotate_half(k) * sin)
return q_rot, k_rot
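# Sanity-check sketch (illustrative shapes): RoPE rotates each (even, odd)
# lane pair by a per-position angle, so vector norms are preserved. The angle
# table below is an arbitrary stand-in, not a real inverse-frequency schedule.
#
#   q = torch.randn(1, 2, 4, 8)                           # [B, H, S, D]
#   k = torch.randn(1, 2, 4, 8)
#   pos = torch.arange(4).unsqueeze(0)                    # [B, S]
#   ang = torch.arange(16.0).reshape(16, 1).repeat(1, 4)  # [max_S, D//2]
#   q2, _ = torch_rope(q, k, torch.sin(ang), torch.cos(ang), pos)
#   assert torch.allclose(q2.norm(dim=-1), q.norm(dim=-1), atol=1e-5)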
class OpTest(BaseOperatorTest):
"""Test Operator Graph"""
def __init__(self):
super().__init__("Graph")
def get_test_cases(self):
return parse_test_cases()
def torch_operator(
self,
hidden_states,
pos_ids,
nhead,
nkvhead,
dim,
past_seqlen,
nlayers,
k_cache,
v_cache,
q_proj_w,
k_proj_w,
v_proj_w,
o_proj_w,
norm_w,
sin_table,
cos_table,
**kwargs,
):
B, S, D = hidden_states.shape
for layer in range(nlayers):
# ---- RMSNorm ----
var = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(var + 1e-5) * norm_w
# ---- QKV projection ----
q = hidden_states @ q_proj_w.T
k = hidden_states @ k_proj_w.T
v = hidden_states @ v_proj_w.T
q = q.view(B, S, nhead, dim).transpose(1, 2) # [B,H,S,Dh]
k = k.view(B, S, nkvhead, dim).transpose(1, 2)
v = v.view(B, S, nkvhead, dim).transpose(1, 2)
# ---- RoPE ----
q, k = torch_rope(
q,
k,
sin_table,
cos_table,
pos_ids,
)
# ---- KV cache update ----
k_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = k
v_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = v
K = k_cache[layer, :, :, 0 : past_seqlen + S, :]
V = v_cache[layer, :, :, 0 : past_seqlen + S, :]
# ---- Scaled Dot Product Attention (fused) ----
def scaled_dot_product_attention(
query, key, value, is_causal=False, enable_gqa=False
) -> torch.Tensor:
S, L = query.size(-2), key.size(-2)
scale_factor = query.size(-1) ** -0.5
attn_bias = torch.zeros(S, L, dtype=query.dtype, device=query.device)
if is_causal:
mask = torch.tril(attn_bias + 1, diagonal=-1).flip(dims=[-2, -1])
attn_bias = torch.where(mask == 1, -torch.inf, 0.0)
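# tril(..., diagonal=-1) flipped along both axes marks exactly the
# strictly-future key positions when the S query rows are aligned with
# the last S of the L key columns, i.e. the KV-cache decoding case.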
if enable_gqa:
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(
query.size(-3) // value.size(-3), -3
)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value
attn_out = scaled_dot_product_attention(
q,
K,
V,
is_causal=True,
enable_gqa=True,
) # [B,H,S,Dh]
# ---- Output projection ----
attn_out = attn_out.transpose(1, 2).contiguous()
attn_out = attn_out.view(B, S, nhead * dim)
hidden_states = attn_out @ o_proj_w.T
return hidden_states
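# Note: the reference stack above covers norm -> QKV -> RoPE -> KV-cache
# update -> attention -> output projection per layer; residual adds and the
# MLP block are not modeled by this graph test.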
def infinicore_operator(
self,
hidden_states,
pos_ids,
nhead,
nkvhead,
dim,
past_seqlen,
nlayers,
k_cache,
v_cache,
q_proj_w,
k_proj_w,
v_proj_w,
o_proj_w,
norm_w,
sin_table,
cos_table,
**kwargs,
):
"""Record graph and run"""
input_hidden_states = hidden_states
B, S, D = input_hidden_states.shape
infinicore.start_graph_recording()
for layer in range(nlayers):
hidden_states = infinicore.nn.functional.rms_norm(
hidden_states, norm_w.shape, norm_w, 1e-5
)
q = infinicore.nn.functional.linear(hidden_states, q_proj_w)
k = infinicore.nn.functional.linear(hidden_states, k_proj_w)
v = infinicore.nn.functional.linear(hidden_states, v_proj_w)
q = q.view((B, S, nhead, dim))
k = k.view((B, S, nkvhead, dim))
v = v.view((B, S, nkvhead, dim))
q = infinicore.nn.functional.rope(
q,
pos_ids,
sin_table,
cos_table,
infinicore.nn.functional.RopeAlgo.GPT_J,
)
k = infinicore.nn.functional.rope(
k,
pos_ids,
sin_table,
cos_table,
infinicore.nn.functional.RopeAlgo.GPT_J,
)
# [B, KVH, total_len, D]
full_k = (
k_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S)
)
full_v = (
v_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S)
)
full_k.narrow(2, past_seqlen, S).copy_(k.permute((0, 2, 1, 3)))
full_v.narrow(2, past_seqlen, S).copy_(v.permute((0, 2, 1, 3)))
G = nhead // nkvhead
L = past_seqlen + S
full_q = (
q.permute((0, 2, 1, 3)).contiguous().view((B * nkvhead, G * S, dim))
)
full_k = full_k.view((B * nkvhead, L, dim))
full_v = full_v.view((B * nkvhead, L, dim))
attn_score = infinicore.matmul(
full_q, full_k.permute((0, 2, 1)), alpha=dim**-0.5
)
# [B * H, S, total_len]
attn_score = attn_score.view((B * nhead, S, L))
infinicore.nn.functional.causal_softmax(attn_score, out=attn_score)
attn_out = infinicore.matmul(attn_score, full_v)
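# NOTE: the test data above uses nhead == nkvhead (G == 1), so this
# (B * nhead, S, L) view matches full_v's batch dimension; a G > 1
# configuration would need attn_score viewed back to
# (B * nkvhead, G * S, L) before this matmul.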
attn_out = (
attn_out.view((B, nhead, S, dim))
.permute((0, 2, 1, 3))
.contiguous()
.view((B, S, nhead * dim))
)
hidden_states = infinicore.nn.functional.linear(attn_out, o_proj_w)
op_graph = infinicore.stop_graph_recording()
op_graph.run()
return hidden_states
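# Note on the pattern above: every infinicore op issued between
# start_graph_recording() and stop_graph_recording() is captured into
# op_graph instead of executing eagerly, and op_graph.run() then replays
# the whole decoder stack in a single launch.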
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
"""
Test if embedding supports CUDA Graph recording
Usage:
python test/infinicore/nn/test_embedding_graph_recording.py
Key verification points:
1. Before modification: indices->to(cpu_device) triggers synchronous D2H copy, causing graph recording to fail
2. After modification: Uses device-side CUDA kernel, fully asynchronous, supports graph recording
Expected results:
- Before modification: Graph recording fails, device-side input may fail
- After modification: Graph recording succeeds, device-side input succeeds
"""
import sys

import infinicore
import torch
def test_embedding_graph_recording():
"""Test if embedding supports CUDA Graph recording"""
print("=" * 60)
print("Testing Embedding Graph Recording Support")
print("=" * 60)
# Check if CUDA is available
if not torch.cuda.is_available():
print("⚠ CUDA not available, skipping graph recording test")
return False
device = infinicore.device("cuda", 0)
# Create embedding module
vocab_size = 1000
embedding_dim = 128
embedding = infinicore.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
dtype=infinicore.float32,
device=device,
)
# Create device-side input_ids (key point: unsupported before modification, supported after)
batch_size = 4
seq_len = 32
input_ids_device = infinicore.from_list(
[[i % vocab_size for i in range(seq_len)] for _ in range(batch_size)],
dtype=infinicore.int64,
device=device,
)
print(f"\n1. Input tensor information:")
print(f" - Shape: {input_ids_device.shape}")
print(f" - Device: {input_ids_device.device.type}")
print(f" - Dtype: {input_ids_device.dtype}")
# Attempt CUDA Graph recording
print(f"\n2. Attempting CUDA Graph recording...")
# Use PyTorch's CUDA Graph API for testing (simpler and more reliable)
try:
# Set device
infinicore.set_device(device)
# Use PyTorch's CUDA Graph API
# Note: torch.cuda.graph is available in PyTorch 1.10+
try:
# Method 1: Use PyTorch CUDA Graph (recommended)
print(" Using PyTorch CUDA Graph API for testing...")
# Create warmup input
warmup_input = input_ids_device
# Warmup (must run once before capture so memory allocations happen outside the graph)
embedding.forward(warmup_input)
infinicore.sync_stream() # Synchronize to ensure warmup completes
# Pre-allocate output tensor (CUDA Graph doesn't support dynamic memory allocation)
# Output shape: input_shape + [embedding_dim]
output_shape = list(input_ids_device.shape) + [embedding_dim]
output = infinicore.empty(
output_shape, dtype=embedding.weight.dtype, device=device
)
# Warmup embedding (ensure memory allocation is complete)
import infinicore.nn.functional as F
F.embedding(warmup_input, embedding.weight, out=output)
infinicore.sync_stream()
# Start graph recording (using pre-allocated output)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
# Use embedding's out parameter (in-place), passing pre-allocated output
F.embedding(input_ids_device, embedding.weight, out=output)
print(" ✓ Graph recording successful!")
print(" ✓ Embedding supports CUDA Graph recording")
# Verify graph can be replayed
graph.replay()
infinicore.sync_stream()
print(" ✓ Graph can be successfully replayed")
return True
except AttributeError:
# PyTorch version may not support torch.cuda.graph
print(
" ⚠ PyTorch version doesn't support torch.cuda.graph, using simplified verification method"
)
return test_embedding_async_verification(embedding, input_ids_device)
except RuntimeError as e:
error_msg = str(e)
if "capture" in error_msg.lower() or "graph" in error_msg.lower():
print(f" ✗ Graph recording failed: {e}")
print(
" ✗ Embedding doesn't support CUDA Graph recording (may contain synchronous operations)"
)
return False
else:
print(f" ⚠ Graph recording test exception: {e}")
return test_embedding_async_verification(embedding, input_ids_device)
except Exception as e:
print(f" ⚠ Graph recording test exception: {e}")
print(" Using simplified verification method...")
import traceback
traceback.print_exc()
return test_embedding_async_verification(embedding, input_ids_device)
def test_embedding_async_verification(embedding, input_ids_device):
"""
Simplified verification: Check if there are synchronous operations
Key checkpoints:
1. Whether input can be on device (needed CPU before modification, supports device after)
2. Whether operations are fully asynchronous (no synchronization points)
"""
print("\n3. Simplified verification: Checking asynchronous operation support")
# Verification 1: Input can be on device
if input_ids_device.device.type != "cuda":
print(" ✗ Input not on device, cannot verify")
return False
print(" ✓ Input is on device")
# Verification 2: Execute forward, check for synchronous operations
# Before modification, this would call indices->to(cpu_device), triggering synchronization
# After modification, directly uses device-side kernel, fully asynchronous
try:
# Record start time
start_event = infinicore.DeviceEvent(enable_timing=True)
end_event = infinicore.DeviceEvent(enable_timing=True)
start_event.record()
output = embedding.forward(input_ids_device)
end_event.record()
# Don't synchronize immediately, check if operation is asynchronous
# If operation is asynchronous, query should return False (not completed)
# If operation is synchronous, may have already completed
# Wait a short time
import time
time.sleep(0.001) # 1ms
# Check event status
is_complete = end_event.query()
if not is_complete:
print(" ✓ Operation is asynchronous (event not immediately completed)")
else:
print(
" ⚠ Operation may contain synchronization points (event immediately completed)"
)
# Synchronize and measure time
end_event.synchronize()
elapsed = start_event.elapsed_time(end_event)
print(f" ✓ Forward execution time: {elapsed:.3f} ms")
print(f" ✓ Output shape: {output.shape}")
print(f" ✓ Output device: {output.device.type}")
# Verify output correctness
embedding_dim = embedding.embedding_dim()
expected_shape = (*input_ids_device.shape, embedding_dim)
if output.device.type == "cuda" and output.shape == expected_shape:
print(" ✓ Output on device, shape correct")
return True
else:
print(f" ✗ Output verification failed")
print(
f" Expected shape: {expected_shape}, actual shape: {output.shape}"
)
print(f" Expected device: cuda, actual device: {output.device.type}")
return False
except Exception as e:
print(f" ✗ Verification failed: {e}")
import traceback
traceback.print_exc()
return False
def test_embedding_device_input_support():
"""Test if embedding supports device-side input"""
print("\n" + "=" * 60)
print("Testing Embedding Device-side Input Support")
print("=" * 60)
if not torch.cuda.is_available():
print("⚠ CUDA not available, skipping test")
return False
device = infinicore.device("cuda", 0)
vocab_size = 100
embedding_dim = 64
embedding = infinicore.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
dtype=infinicore.float32,
device=device,
)
# Test 1: Device-side input (supported after modification)
print("\nTest 1: Device-side input")
try:
input_ids_device = infinicore.from_list(
[[1, 2, 3, 4, 5]], dtype=infinicore.int64, device=device
)
output = embedding.forward(input_ids_device)
print(f" ✓ Device-side input successful")
print(f" - Input device: {input_ids_device.device.type}")
print(f" - Output device: {output.device.type}")
print(f" - Output shape: {output.shape}")
return True
except Exception as e:
print(f" ✗ Device-side input failed: {e}")
return False
def main():
"""Main test function"""
print("\n" + "=" * 60)
print("Embedding Graph Recording Support Verification")
print("=" * 60)
results = []
# Test 1: Graph recording support
result1 = test_embedding_graph_recording()
results.append(("CUDA Graph Recording", result1))
# Test 2: Device-side input support
result2 = test_embedding_device_input_support()
results.append(("Device-side Input", result2))
# Summary
print("\n" + "=" * 60)
print("Test Results Summary")
print("=" * 60)
all_passed = True
for test_name, result in results:
status = "✓ Passed" if result else "✗ Failed"
print(f"{test_name}: {status}")
if not result:
all_passed = False
print("\n" + "=" * 60)
if all_passed:
print("✓ All tests passed! Embedding supports graph recording")
else:
print("✗ Some tests failed, embedding may not fully support graph recording")
print("=" * 60)
return all_passed
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)