Unverified Commit f5e6d729 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Issue/497 - Enhanced Test Framework (#520)

* issue/497 - add dtype __eq__ and __hash__

* issue/497 - simplified infinicore test functions

* issue/497 - improved test framework

greatly reduced the code required for specific operators;
added strided tensor support;

* issue/497 - add add interface to assist test

* issue/497 - generalized test framework based on add

* issue/497 - support non-contiguous tensors in result comparison

* issue/497 - temporarily fixed strided tensor creation

* issue/497 - rms norm interface

* issue/497 - now requires test function definition

* issue/497 - support mixed dtype

* issue/497 - initial rms norm test

* issue/497 - unified in place and out of place tests

* issue/497 - renamed src/infinicore/op

* issue/497 - reduced comments

* issue/497 - attention

* issue/497 - removed generic parameter mapping

* issue/497 - temporary attention test

* issue/497 - captitalize op name initial

* issue/497 - add a script to run all op tests

* issue/497 - fix comments

* issue/497 - simplified infinicore tensor creation from torch

* issue/497 - support tensor init modes

* issue/497 - support tensor from/to files

* issue/497 - adjust naming
parent 37c76a90
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/op/common/cache.hpp"
#include "infinicore/op/matmul.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/matmul.hpp"
#include <infiniop.h>
namespace infinicore::op::matmul_impl::infiniop {
......@@ -27,7 +27,9 @@ void calculate(Tensor c, Tensor a, Tensor b) {
infiniopGemmDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(context::getInfiniopHandle(), &desc, c->desc(), a->desc(), b->desc()));
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
context::getInfiniopHandle(), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
......
#include "infinicore/op/ones.hpp"
#include "infinicore/ops/ones.hpp"
namespace infinicore::op {
......
#include "infinicore/op/rearrange.hpp"
#include "infinicore/ops/rearrange.hpp"
namespace infinicore::op {
......
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/op/common/cache.hpp"
#include "infinicore/op/rearrange.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/rearrange.hpp"
#include <infiniop.h>
namespace infinicore::op::rearrange_impl::infiniop {
......
#include "infinicore/ops/rms_norm.hpp"
namespace infinicore::op {
common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() {
static common::OpDispatcher<RMSNorm::schema> dispatcher_;
return dispatcher_;
};
void RMSNorm::execute(Tensor y, Tensor x, Tensor weight, float epsilon) {
dispatcher().lookup(context::getDevice().getType())(y, x, weight, epsilon);
}
Tensor rms_norm(Tensor x, Tensor weight, float epsilon) {
auto y = Tensor::empty(x->shape(), x->dtype(), x->device());
rms_norm_(y, x, weight, epsilon);
return y;
}
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon) {
RMSNorm::execute(y, x, weight, epsilon);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/rms_norm.hpp"
#include <infiniop.h>
namespace infinicore::op::rms_norm_impl::infiniop {
thread_local common::OpCache<size_t, infiniopRMSNormDescriptor_t> caches(
100, // capacity
[](infiniopRMSNormDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyRMSNormDescriptor(desc));
desc = nullptr;
}
});
void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {
size_t seed = hash_combine(y, x, weight, epsilon);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed);
infiniopRMSNormDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
context::getInfiniopHandle(), &desc,
y->desc(), x->desc(), weight->desc(), epsilon));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopRMSNorm(
desc, workspace->data(), workspace_size,
y->data(), x->data(), weight->data(), context::getStream()));
}
static bool registered = []() {
RMSNorm::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::rms_norm_impl::infiniop
......@@ -4,7 +4,7 @@
#include "context.hpp"
#include "device.hpp"
#include "dtype.hpp"
#include "op.hpp"
#include "ops.hpp"
#include "tensor.hpp"
namespace infinicore {
......@@ -13,7 +13,7 @@ PYBIND11_MODULE(_infinicore, m) {
context::bind(m);
device::bind(m);
dtype::bind(m);
op::bind(m);
ops::bind(m);
tensor::bind(m);
}
......
......@@ -2,16 +2,22 @@
#include <pybind11/pybind11.h>
#include "op/matmul.hpp"
#include "op/rearrange.hpp"
#include "ops/add.hpp"
#include "ops/attention.hpp"
#include "ops/matmul.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
namespace py = pybind11;
namespace infinicore::op {
namespace infinicore::ops {
inline void bind(py::module &m) {
bind_add(m);
bind_attention(m);
bind_matmul(m);
bind_rearrange(m);
bind_rms_norm(m);
}
} // namespace infinicore::op
} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/add.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_add(py::module &m) {
m.def("add",
&op::add,
py::arg("a"),
py::arg("b"),
R"doc(Addition of two tensors.)doc");
m.def("add_",
&op::add_,
py::arg("c"),
py::arg("a"),
py::arg("b"),
R"doc(In-place tensor addition.)doc");
}
} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/attention.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_attention(py::module &m) {
m.def("attention",
&op::attention,
py::arg("q"),
py::arg("k"),
py::arg("v"),
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("pos"),
R"doc(Attention mechanism with KV caching.
Args:
q: Query tensor
k: Key tensor
v: Value tensor
k_cache: Key cache tensor
v_cache: Value cache tensor
pos: Current position in the sequence
Returns:
Output tensor from attention computation
)doc");
m.def("attention_",
&op::Attention::execute,
py::arg("out"),
py::arg("q"),
py::arg("k"),
py::arg("v"),
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("pos"),
R"doc(In-place attention mechanism with KV caching.
Args:
out: Output tensor
q: Query tensor
k: Key tensor
v: Value tensor
k_cache: Key cache tensor
v_cache: Value cache tensor
pos: Current position in the sequence
)doc");
}
} // namespace infinicore::ops
......@@ -2,11 +2,11 @@
#include <pybind11/pybind11.h>
#include "infinicore/op/matmul.hpp"
#include "infinicore/ops/matmul.hpp"
namespace py = pybind11;
namespace infinicore::op {
namespace infinicore::ops {
inline void bind_matmul(py::module &m) {
m.def("matmul",
......@@ -23,4 +23,4 @@ inline void bind_matmul(py::module &m) {
R"doc(In-place matrix multiplication.)doc");
}
} // namespace infinicore::op
} // namespace infinicore::ops
......@@ -2,11 +2,11 @@
#include <pybind11/pybind11.h>
#include "infinicore/op/rearrange.hpp"
#include "infinicore/ops/rearrange.hpp"
namespace py = pybind11;
namespace infinicore::op {
namespace infinicore::ops {
inline void bind_rearrange(py::module &m) {
m.def("rearrange",
......@@ -21,4 +21,4 @@ inline void bind_rearrange(py::module &m) {
R"doc(In-place tensor rearrangement.)doc");
}
} // namespace infinicore::op
} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/rms_norm.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_rms_norm(py::module &m) {
m.def("rms_norm",
&op::rms_norm,
py::arg("x"),
py::arg("weight"),
py::arg("epsilon") = 1e-5f,
R"doc(Root Mean Square Normalization.
Args:
x: Input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
Returns:
Normalized tensor with same shape as input
)doc");
m.def("rms_norm_",
&op::rms_norm_,
py::arg("y"),
py::arg("x"),
py::arg("weight"),
py::arg("epsilon") = 1e-5f,
R"doc(In-place Root Mean Square Normalization.
Args:
y: Output tensor
x: Input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
)doc");
}
} // namespace infinicore::ops
from .base import TestConfig, TestRunner, TestCase
# [file name]: __init__.py
# [file content begin]
from .base import TestConfig, TestRunner, TestCase, BaseOperatorTest
from .tensor import TensorSpec, TensorInitializer
from .utils import (
create_infinicore_tensor,
compare_results,
create_test_comparator,
debug,
get_tolerance,
infinicore_tensor_from_torch,
profile_operation,
rearrange_tensor,
convert_infinicore_to_torch,
)
from .config import get_test_devices, get_args
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .runner import GenericTestRunner
from .templates import BinaryOperatorTest, UnaryOperatorTest
__all__ = [
"TensorSpec",
"TensorInitializer",
"TestConfig",
"TestRunner",
"TestCase",
"create_infinicore_tensor",
"BaseOperatorTest",
"compare_results",
"create_test_comparator",
"convert_infinicore_to_torch",
"debug",
"get_args",
"get_test_devices",
"get_tolerance",
"infinicore_tensor_from_torch",
"profile_operation",
"rearrange_tensor",
"get_test_devices",
"get_args",
"InfiniDeviceEnum",
"InfiniDeviceNames",
"torch_device_map",
"to_torch_dtype",
"to_infinicore_dtype",
"GenericTestRunner",
"BinaryOperatorTest",
"UnaryOperatorTest",
]
import torch
import infinicore
from .devices import InfiniDeviceNames
from .utils import synchronize_device
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple, Union, Callable, Optional
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
from .utils import (
create_test_comparator,
infinicore_tensor_from_torch,
profile_operation,
rearrange_tensor,
synchronize_device,
)
class TestCase:
"""Base test case class"""
"""Test case"""
OUT_OF_PLACE = "out_of_place"
IN_PLACE = "in_place"
BOTH = "both"
def __init__(self, operation_mode, inputs, output=None, **kwargs):
if operation_mode not in [self.IN_PLACE, self.OUT_OF_PLACE, self.BOTH]:
raise ValueError(f"Invalid operation_mode: {operation_mode}")
if operation_mode == self.IN_PLACE and output is None:
raise ValueError("IN_PLACE mode requires output specification")
self.operation_mode = operation_mode
self.inputs = []
for inp in inputs:
if isinstance(inp, (list, tuple)):
self.inputs.append(TensorSpec.from_tensor(inp))
elif isinstance(inp, TensorSpec):
self.inputs.append(inp)
else:
self.inputs.append(inp)
if isinstance(output, (list, tuple)):
self.output = TensorSpec.from_tensor(output)
else:
self.output = output
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.description = kwargs.pop("description", "")
def __str__(self):
return f"TestCase{self.args}"
mode_str = self.operation_mode.upper()
input_strs = []
for inp in self.inputs:
if hasattr(inp, "is_scalar") and inp.is_scalar:
dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
input_strs.append(f"scalar({inp.value}{dtype_str})")
elif hasattr(inp, "shape"):
dtype_str = f", dtype={inp.dtype}" if inp.dtype else ""
init_str = (
f", init={inp.init_mode}"
if inp.init_mode != TensorInitializer.RANDOM
else ""
)
if hasattr(inp, "is_contiguous") and not inp.is_contiguous:
input_strs.append(f"strided_tensor{inp.shape}{dtype_str}{init_str}")
else:
input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}")
else:
input_strs.append(str(inp))
base_str = f"TestCase(mode={mode_str}, inputs=[{', '.join(input_strs)}]"
if self.output:
dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else ""
init_str = (
f", init={self.output.init_mode}"
if self.output.init_mode != TensorInitializer.RANDOM
else ""
)
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
if self.kwargs:
base_str += f", kwargs={self.kwargs}"
if self.description:
base_str += f", desc='{self.description}'"
base_str += ")"
return base_str
class TestConfig:
......@@ -26,6 +98,7 @@ class TestConfig:
bench=False,
num_prerun=10,
num_iterations=1000,
dtype_combinations=None,
):
self.tensor_dtypes = tensor_dtypes
self.tolerance_map = tolerance_map
......@@ -33,6 +106,7 @@ class TestConfig:
self.bench = bench
self.num_prerun = num_prerun
self.num_iterations = num_iterations
self.dtype_combinations = dtype_combinations
class TestRunner:
......@@ -41,45 +115,61 @@ class TestRunner:
def __init__(self, test_cases, test_config):
self.test_cases = test_cases
self.config = test_config
self.failed_tests = [] # Track failures
self.failed_tests = []
def run_tests(self, devices, test_func):
"""Run tests and track failures"""
def run_tests(self, devices, test_func, test_type="Test"):
for device in devices:
print(f"\n{'='*60}")
print(f"Testing on {InfiniDeviceNames[device]}")
print(f"Testing {test_type} on {InfiniDeviceNames[device]}")
print(f"{'='*60}")
# filter unsupported data types
tensor_dtypes = self._filter_tensor_dtypes_by_device(
device, self.config.tensor_dtypes
)
for test_case in self.test_cases:
for dtype in tensor_dtypes:
try:
test_func(device, test_case, dtype, self.config)
print(f"✓ {test_case} with {dtype} passed")
except Exception as e:
error_msg = f"{test_case} with {dtype} on {InfiniDeviceNames[device]}: {e}"
print(f"✗ {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
# Return whether any tests failed
if self.config.dtype_combinations:
for dtype_combo in self.config.dtype_combinations:
try:
test_func(device, test_case, dtype_combo, self.config)
combo_str = self._format_dtype_combo(dtype_combo)
print(f"✓ {test_case} with {combo_str} passed")
except Exception as e:
combo_str = self._format_dtype_combo(dtype_combo)
error_msg = f"{test_case} with {combo_str} on {InfiniDeviceNames[device]}: {e}"
print(f"✗ {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
else:
for dtype in tensor_dtypes:
try:
test_func(device, test_case, dtype, self.config)
print(f"✓ {test_case} with {dtype} passed")
except Exception as e:
error_msg = f"{test_case} with {dtype} on {InfiniDeviceNames[device]}: {e}"
print(f"✗ {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
return len(self.failed_tests) == 0
def _format_dtype_combo(self, dtype_combo):
if isinstance(dtype_combo, dict):
return f"dtypes({dtype_combo})"
elif isinstance(dtype_combo, (list, tuple)):
return f"dtypes{tuple(dtype_combo)}"
else:
return str(dtype_combo)
def _filter_tensor_dtypes_by_device(self, device, tensor_dtypes):
"""Filter data types based on device"""
if device in ():
# Filter out unsupported data types on specified devices
return [dt for dt in tensor_dtypes if dt != infinicore.bfloat16]
else:
return tensor_dtypes
def print_summary(self):
"""Print test summary"""
if self.failed_tests:
print(f"\n\033[91m{len(self.failed_tests)} tests failed:\033[0m")
for failure in self.failed_tests:
......@@ -88,3 +178,246 @@ class TestRunner:
else:
print("\n\033[92mAll tests passed!\033[0m")
return True
class BaseOperatorTest(ABC):
"""Base operator test"""
def __init__(self, operator_name):
self.operator_name = operator_name
self.test_cases = self.get_test_cases()
self.tensor_dtypes = self.get_tensor_dtypes()
self.tolerance_map = self.get_tolerance_map()
self.dtype_combinations = self.get_dtype_combinations()
@abstractmethod
def get_test_cases(self):
"""Return list of TestCase objects"""
pass
@abstractmethod
def get_tensor_dtypes(self):
"""Return supported data types"""
pass
@abstractmethod
def get_tolerance_map(self):
"""Return tolerance configuration"""
pass
def get_dtype_combinations(self):
"""Return dtype combinations for mixed dtype tests"""
return None
@abstractmethod
def torch_operator(self, *inputs, out=None, **kwargs):
"""Unified PyTorch operator function"""
pass
@abstractmethod
def infinicore_operator(self, *inputs, out=None, **kwargs):
"""Unified Infinicore operator function"""
pass
def create_strided_tensor(
self, shape, strides, dtype, device, init_mode=TensorInitializer.RANDOM
):
"""Create a non-contiguous tensor with specific strides"""
spec = TensorSpec.from_strided_tensor(shape, strides, dtype, init_mode)
return spec.create_torch_tensor(device, dtype)
def prepare_inputs(self, test_case, device, dtype_config):
"""Prepare input data"""
inputs = []
for i, input_spec in enumerate(test_case.inputs):
if isinstance(input_spec, TensorSpec):
if input_spec.is_scalar:
inputs.append(input_spec.value)
else:
tensor = input_spec.create_torch_tensor(device, dtype_config, i)
inputs.append(tensor)
else:
inputs.append(input_spec)
return inputs, test_case.kwargs
def get_output_dtype(self, test_case, dtype_config, torch_result=None):
"""Determine output dtype - returns infinicore dtype, not torch dtype"""
if test_case.output and test_case.output.dtype is not None:
return test_case.output.dtype
elif isinstance(dtype_config, dict) and "output" in dtype_config:
return dtype_config["output"]
elif torch_result is not None:
return to_infinicore_dtype(torch_result.dtype)
else:
if isinstance(dtype_config, (list, tuple)):
return dtype_config[0]
else:
return dtype_config
def run_test(self, device, test_case, dtype_config, config):
"""Unified test execution flow"""
device_str = torch_device_map[device]
if test_case.operation_mode == TestCase.BOTH:
out_of_place_case = TestCase(
TestCase.OUT_OF_PLACE,
test_case.inputs,
test_case.output,
**test_case.kwargs,
)
self._run_single_test(
device, out_of_place_case, dtype_config, config, "OUT_OF_PLACE"
)
if test_case.output is not None:
in_place_case = TestCase(
TestCase.IN_PLACE,
test_case.inputs,
test_case.output,
**test_case.kwargs,
)
self._run_single_test(
device, in_place_case, dtype_config, config, "IN_PLACE"
)
return
self._run_single_test(
device, test_case, dtype_config, config, test_case.operation_mode.upper()
)
def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
"""Run a single test with specified operation mode"""
device_str = torch_device_map[device]
inputs, kwargs = self.prepare_inputs(test_case, device, dtype_config)
infini_inputs = []
for inp in inputs:
if isinstance(inp, torch.Tensor):
infini_tensor = infinicore_tensor_from_torch(inp)
infini_inputs.append(infini_tensor)
else:
infini_inputs.append(inp)
if test_case.operation_mode == TestCase.OUT_OF_PLACE:
def torch_op():
return self.torch_operator(*inputs, **kwargs)
torch_result = torch_op()
if (
isinstance(torch_result, torch.Tensor)
and not torch_result.is_contiguous()
):
torch_result = torch_result.contiguous()
def infini_op():
return self.infinicore_operator(*infini_inputs, **kwargs)
infini_result = infini_op()
# Get comparison dtype (infinicore dtype)
comparison_dtype = self.get_output_dtype(
test_case, dtype_config, torch_result
)
compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"
)
is_valid = compare_fn(infini_result, torch_result)
assert is_valid, f"{self.operator_name} {mode_name} test failed"
if config.bench:
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
infini_op,
device_str,
config.num_prerun,
config.num_iterations,
)
else:
if not test_case.output:
raise ValueError("IN_PLACE test requires output specification")
# Get output dtype and create output tensor
output_dtype = self.get_output_dtype(test_case, dtype_config)
output_shape = test_case.output.shape
# Use TensorSpec to create output tensor with specified initialization mode
if test_case.output.is_contiguous or test_case.output.strides is None:
output_spec = TensorSpec.from_tensor(
output_shape, output_dtype, init_mode=test_case.output.init_mode
)
else:
output_spec = TensorSpec.from_strided_tensor(
output_shape,
test_case.output.strides,
output_dtype,
init_mode=test_case.output.init_mode,
)
torch_output = output_spec.create_torch_tensor(device, output_dtype)
# For non-contiguous tensors, we need to ensure zeros initialization
if (
not test_case.output.is_contiguous
and test_case.output.strides is not None
):
torch_output.zero_()
def torch_op_inplace():
self.torch_operator(*inputs, out=torch_output, **kwargs)
torch_op_inplace()
# Create infinicore output tensor
torch_dummy = torch.zeros(
output_shape, dtype=to_torch_dtype(output_dtype), device=device_str
)
if (
not test_case.output.is_contiguous
and not test_case.output.strides is None
):
rearrange_tensor(torch_dummy, list(torch_output.stride()))
infini_output = infinicore_tensor_from_torch(torch_dummy)
def infini_op_inplace():
self.infinicore_operator(*infini_inputs, out=infini_output, **kwargs)
infini_op_inplace()
comparison_dtype = self.get_output_dtype(
test_case, dtype_config, torch_output
)
compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"
)
is_valid = compare_fn(infini_output, torch_output)
assert is_valid, f"{self.operator_name} {mode_name} test failed"
if config.bench:
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
torch_op_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
infini_op_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
"""
Generic test runner that handles the common execution flow for all operators
"""
import sys
from . import TestConfig, TestRunner, get_args, get_test_devices
class GenericTestRunner:
"""Generic test runner that handles the common execution flow"""
def __init__(self, operator_test_class):
"""
Args:
operator_test_class: A class that implements BaseOperatorTest interface
"""
self.operator_test = operator_test_class()
self.args = get_args()
def run(self):
"""Execute the complete test suite"""
config = TestConfig(
tensor_dtypes=self.operator_test.tensor_dtypes,
tolerance_map=self.operator_test.tolerance_map,
debug=self.args.debug,
bench=self.args.bench,
num_prerun=self.args.num_prerun,
num_iterations=self.args.num_iterations,
dtype_combinations=self.operator_test.dtype_combinations,
)
runner = TestRunner(self.operator_test.test_cases, config)
devices = get_test_devices(self.args)
# Run unified tests
all_passed = runner.run_tests(
devices, self.operator_test.run_test, self.operator_test.operator_name
)
# Print summary
summary_passed = runner.print_summary()
all_passed = all_passed and summary_passed
return all_passed
def run_and_exit(self):
"""Run tests and exit with appropriate status code"""
success = self.run()
sys.exit(0 if success else 1)
"""
Templates for common operator patterns to minimize code duplication
Available configuration methods in BaseOperatorTest:
1. get_test_cases() -> List[TestCase]
- Define input/output shapes, strides, and operation modes
- Operation modes: TestCase.OUT_OF_PLACE, TestCase.IN_PLACE, TestCase.BOTH
2. get_tensor_dtypes() -> List[infinicore.dtype]
- Define supported data types for single-dtype tests
- Used when dtype_combinations is None
3. get_tolerance_map() -> Dict[infinicore.dtype, Dict[str, float]]
- Set tolerance (atol, rtol) for each data type
- Example: {infinicore.float16: {"atol": 1e-3, "rtol": 1e-2}}
4. get_dtype_combinations() -> Optional[List[Dict]]
- Define mixed dtype configurations for multi-dtype tests
- Return None for single-dtype tests
5. torch_operator(*inputs, out=None, **kwargs) -> torch.Tensor
- Implement PyTorch reference implementation
6. infinicore_operator(*inputs, out=None, **kwargs) -> infinicore.Tensor
- Implement Infinicore operator implementation
New Tensor Initialization Modes:
- TensorInitializer.RANDOM (default): Random values using torch.rand
- TensorInitializer.ZEROS: All zeros using torch.zeros
- TensorInitializer.ONES: All ones using torch.ones
- TensorInitializer.RANDINT: Random integers using torch.randint
- TensorInitializer.MANUAL: Use a pre-existing tensor with shape/strides validation
- TensorInitializer.BINARY: Use a pre-existing tensor with shape validation only
Usage examples in TestCase creation:
- Basic: TensorSpec.from_tensor(shape)
- With initialization: TensorSpec.from_tensor(shape, init_mode=TensorInitializer.ZEROS)
- Strided with custom init: TensorSpec.from_strided_tensor(shape, strides, init_mode=TensorInitializer.ONES)
"""
import torch
import infinicore
from .base import BaseOperatorTest
from .tensor import TensorSpec, TensorInitializer
class BinaryOperatorTest(BaseOperatorTest):
"""Template for binary operators (matmul, add, mul, etc.)"""
def __init__(self, operator_name, test_cases, tensor_dtypes, tolerance_map):
self._operator_name = operator_name
self._test_cases = test_cases
self._tensor_dtypes = tensor_dtypes
self._tolerance_map = tolerance_map
super().__init__(operator_name)
def get_test_cases(self):
return self._test_cases
def get_tensor_dtypes(self):
return self._tensor_dtypes
def get_tolerance_map(self):
return self._tolerance_map
def torch_operator(self, *inputs, **kwargs):
"""Generic torch operator dispatch"""
# Support both functional and method calls
if hasattr(torch, self._operator_name):
op = getattr(torch, self._operator_name)
else:
# Fallback to common operator mappings
op_mapping = {
"matmul": torch.matmul,
"add": torch.add,
"mul": torch.mul,
"sub": torch.sub,
"div": torch.div,
}
op = op_mapping.get(self._operator_name)
if op is None:
raise NotImplementedError(
f"Torch operator {self._operator_name} not implemented"
)
return op(*inputs, **kwargs)
def infinicore_operator(self, *inputs, **kwargs):
"""Generic infinicore operator dispatch"""
op = getattr(infinicore, self._operator_name)
return op(*inputs, **kwargs)
class UnaryOperatorTest(BinaryOperatorTest):
"""Template for unary operators (exp, log, sin, etc.)"""
def torch_operator(self, *inputs, **kwargs):
# For unary operators, we only use the first input
if hasattr(torch, self._operator_name):
op = getattr(torch, self._operator_name)
return op(inputs[0], **kwargs)
else:
return super().torch_operator(*inputs, **kwargs)
def infinicore_operator(self, *inputs, **kwargs):
op = getattr(infinicore, self._operator_name)
return op(inputs[0], **kwargs)
import torch
from pathlib import Path
from .datatypes import to_torch_dtype
from .devices import torch_device_map
class TensorInitializer:
"""Tensor data initializer with multiple modes"""
RANDOM = "random"
ZEROS = "zeros"
ONES = "ones"
RANDINT = "randint"
MANUAL = "manual"
BINARY = "binary"
FROM_FILE = "from_file"
@staticmethod
def create_tensor(
shape, dtype, device, mode=RANDOM, strides=None, set_tensor=None, file_path=None
):
"""
Create a torch tensor with specified initialization mode
Args:
shape: Tensor shape
dtype: infinicore dtype
device: InfiniDeviceEnum
mode: Initialization mode
strides: Optional strides for strided tensors
set_tensor: Pre-existing tensor for manual/binary mode
file_path: Path to file for FROM_FILE mode
Returns:
torch.Tensor: Initialized tensor
"""
# Convert InfiniDeviceEnum to torch device string
torch_device_str = torch_device_map[device]
torch_dtype = to_torch_dtype(dtype)
# Handle strided tensors - calculate required storage size
if strides is not None:
# Calculate the required storage size for strided tensor
storage_size = 0
for i in range(len(shape)):
if shape[i] > 0:
storage_size += (shape[i] - 1) * abs(strides[i])
storage_size += 1 # Add 1 for the base element
# Create base storage with sufficient size
if mode == TensorInitializer.RANDOM:
base_tensor = torch.rand(
storage_size, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.ZEROS:
base_tensor = torch.zeros(
storage_size, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.ONES:
base_tensor = torch.ones(
storage_size, dtype=torch_dtype, device=torch_device_str
)
elif mode == TensorInitializer.RANDINT:
base_tensor = torch.randint(
-2000000000,
2000000000,
(storage_size,),
dtype=torch_dtype,
device=torch_device_str,
)
elif mode == TensorInitializer.MANUAL:
assert set_tensor is not None, "Manual mode requires set_tensor"
base_tensor = set_tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.BINARY:
assert set_tensor is not None, "Binary mode requires set_tensor"
base_tensor = set_tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.FROM_FILE:
base_tensor = TensorInitializer._load_from_file(
file_path, storage_size, torch_dtype, torch_device_str
)
else:
raise ValueError(f"Unsupported initialization mode: {mode}")
# Create strided view
tensor = torch.as_strided(base_tensor, shape, strides)
else:
# Contiguous tensor
if mode == TensorInitializer.RANDOM:
tensor = torch.rand(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ZEROS:
tensor = torch.zeros(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.ONES:
tensor = torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
elif mode == TensorInitializer.RANDINT:
tensor = torch.randint(
-2000000000,
2000000000,
shape,
dtype=torch_dtype,
device=torch_device_str,
)
elif mode == TensorInitializer.MANUAL:
assert set_tensor is not None, "Manual mode requires set_tensor"
assert shape == list(set_tensor.shape), "Shape mismatch in manual mode"
tensor = set_tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.BINARY:
assert set_tensor is not None, "Binary mode requires set_tensor"
assert shape == list(set_tensor.shape), "Shape mismatch in binary mode"
tensor = set_tensor.to(torch_dtype).to(torch_device_str)
elif mode == TensorInitializer.FROM_FILE:
tensor = TensorInitializer._load_from_file(
file_path, shape, torch_dtype, torch_device_str
)
else:
raise ValueError(f"Unsupported initialization mode: {mode}")
return tensor
@staticmethod
def _load_from_file(file_path, shape_or_size, torch_dtype, torch_device_str):
"""
Load tensor data from file using PyTorch's native methods
Args:
file_path: Path to the file
shape_or_size: Tensor shape for contiguous or size for strided
torch_dtype: Target torch dtype
torch_device_str: Target device string
Returns:
torch.Tensor: Tensor with data loaded from file
"""
if file_path is None:
raise ValueError("FROM_FILE mode requires file_path")
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
# Determine file type and load accordingly
file_extension = file_path.suffix.lower()
if file_extension in [".pt", ".pth"]:
# PyTorch native format
tensor = torch.load(file_path, map_location=torch_device_str)
elif file_extension in [".bin", ".dat", ".raw"]:
# Raw binary format - we need to know the expected shape
tensor = TensorInitializer._load_binary_file(
file_path, shape_or_size, torch_dtype, torch_device_str
)
elif file_extension in [".npy"]:
# NumPy format - fallback to numpy if needed
try:
import numpy as np
numpy_array = np.load(file_path)
tensor = (
torch.from_numpy(numpy_array).to(torch_dtype).to(torch_device_str)
)
except ImportError:
raise ImportError("NumPy is required to load .npy files")
else:
# Try to load as PyTorch format first, then fallback to binary
try:
tensor = torch.load(file_path, map_location=torch_device_str)
except:
# Fallback to binary loading
tensor = TensorInitializer._load_binary_file(
file_path, shape_or_size, torch_dtype, torch_device_str
)
# Ensure correct dtype and device
tensor = tensor.to(torch_dtype).to(torch_device_str)
# Validate shape/size
if isinstance(shape_or_size, (list, tuple)):
# Contiguous tensor - check shape
if list(tensor.shape) != list(shape_or_size):
raise ValueError(
f"Tensor shape mismatch: expected {shape_or_size}, got {tensor.shape}"
)
else:
# Strided tensor - check total size
if tensor.numel() != shape_or_size:
raise ValueError(
f"Tensor size mismatch: expected {shape_or_size} elements, got {tensor.numel()}"
)
return tensor
@staticmethod
def _load_binary_file(file_path, shape_or_size, torch_dtype, torch_device_str):
"""
Load tensor from raw binary file
Args:
file_path: Path to binary file
shape_or_size: Expected shape or size
torch_dtype: Target dtype
torch_device_str: Target device
Returns:
torch.Tensor: Loaded tensor
"""
# Read binary data
with open(file_path, "rb") as f:
binary_data = f.read()
# Create tensor from buffer
if isinstance(shape_or_size, (list, tuple)):
# Contiguous tensor with known shape
tensor = torch.frombuffer(binary_data, dtype=torch_dtype).reshape(
shape_or_size
)
else:
# Strided tensor - just 1D buffer
tensor = torch.frombuffer(binary_data, dtype=torch_dtype)
return tensor.to(torch_device_str)
@staticmethod
def save_to_file(tensor, file_path, format="auto"):
"""
Save tensor data to file using PyTorch's native methods
Args:
tensor: torch.Tensor to save
file_path: Path to save the file
format: File format ('auto', 'torch', 'binary', 'numpy')
"""
file_path = Path(file_path)
if format == "auto":
# Determine format from file extension
file_extension = file_path.suffix.lower()
if file_extension in [".pt", ".pth"]:
format = "torch"
elif file_extension in [".npy"]:
format = "numpy"
else:
format = "binary"
if format == "torch":
# PyTorch native format (preserves metadata)
torch.save(tensor, file_path)
elif format == "binary":
# Raw binary format
with open(file_path, "wb") as f:
f.write(tensor.cpu().numpy().tobytes())
elif format == "numpy":
# NumPy format
try:
import numpy as np
np.save(file_path, tensor.cpu().numpy())
except ImportError:
raise ImportError("NumPy is required to save .npy files")
else:
raise ValueError(f"Unsupported format: {format}")
print(
f"Tensor saved to {file_path} (shape: {tensor.shape}, dtype: {tensor.dtype}, format: {format})"
)
@staticmethod
def list_supported_formats():
"""Return list of supported file formats"""
return {
"torch": [".pt", ".pth"], # PyTorch native format
"binary": [".bin", ".dat", ".raw"], # Raw binary
"numpy": [".npy"], # NumPy format
}
class TensorSpec:
"""Tensor specification supporting various input types and per-tensor dtype"""
def __init__(
self,
shape=None,
dtype=None,
strides=None,
value=None,
is_scalar=False,
is_contiguous=True,
init_mode=TensorInitializer.RANDOM, # Default to random initialization
custom_tensor=None, # For manual/binary mode
file_path=None, # For FROM_FILE mode
file_format=None, # Optional file format hint
):
self.shape = shape
self.dtype = dtype
self.strides = strides
self.value = value
self.is_scalar = is_scalar
self.is_contiguous = is_contiguous
self.init_mode = init_mode
self.custom_tensor = custom_tensor
self.file_path = file_path
self.file_format = file_format
@classmethod
def from_tensor(
cls,
shape,
dtype=None,
strides=None,
is_contiguous=True,
init_mode=TensorInitializer.RANDOM,
custom_tensor=None,
file_path=None,
):
return cls(
shape=shape,
dtype=dtype,
strides=strides,
is_scalar=False,
is_contiguous=is_contiguous,
init_mode=init_mode,
custom_tensor=custom_tensor,
file_path=file_path,
)
@classmethod
def from_scalar(cls, value, dtype=None):
return cls(value=value, dtype=dtype, is_scalar=True)
@classmethod
def from_strided_tensor(
cls,
shape,
strides,
dtype=None,
init_mode=TensorInitializer.RANDOM,
custom_tensor=None,
file_path=None,
):
return cls(
shape=shape,
dtype=dtype,
strides=strides,
is_scalar=False,
is_contiguous=False,
init_mode=init_mode,
custom_tensor=custom_tensor,
file_path=file_path,
)
@classmethod
def from_file(
cls,
file_path,
shape,
dtype=None,
strides=None,
is_contiguous=True,
file_format=None,
):
"""
Create TensorSpec that loads data from file
Args:
file_path: Path to file
shape: Tensor shape
dtype: infinicore dtype (inferred from file if None)
strides: Optional strides for strided tensors
is_contiguous: Whether tensor is contiguous
file_format: Optional file format hint
Returns:
TensorSpec: Configured for file loading
"""
return cls(
shape=shape,
dtype=dtype,
strides=strides,
is_scalar=False,
is_contiguous=is_contiguous,
init_mode=TensorInitializer.FROM_FILE,
file_path=file_path,
file_format=file_format,
)
def create_torch_tensor(self, device, dtype_config, tensor_index=0):
"""Create a torch tensor based on this specification"""
if self.is_scalar:
return self.value
# Determine dtype - ensure we're using infinicore dtype, not torch dtype
if self.dtype is not None:
tensor_dtype = self.dtype
elif isinstance(dtype_config, dict) and f"input_{tensor_index}" in dtype_config:
tensor_dtype = dtype_config[f"input_{tensor_index}"]
elif isinstance(dtype_config, (list, tuple)) and tensor_index < len(
dtype_config
):
tensor_dtype = dtype_config[tensor_index]
else:
tensor_dtype = dtype_config
# Create tensor using the specified initialization mode
return TensorInitializer.create_tensor(
shape=self.shape,
dtype=tensor_dtype,
device=device,
mode=self.init_mode,
strides=self.strides,
set_tensor=self.custom_tensor,
file_path=self.file_path,
)
......@@ -4,18 +4,6 @@ import infinicore
from .datatypes import to_infinicore_dtype, to_torch_dtype
def create_infinicore_tensor(torch_tensor, device_str):
"""Create infinicore tensor from PyTorch tensor"""
infini_device = infinicore.device(device_str, 0)
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
def synchronize_device(torch_device):
"""Device synchronization"""
if torch_device == "cuda":
......@@ -117,7 +105,6 @@ def print_discrepancy(
f"delta: {add_color(delta_str, 33)}"
)
print(add_color(" INFO:", 35))
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
......@@ -149,44 +136,103 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
return tolerance["atol"], tolerance["rtol"]
def compare_results(
infini_result, torch_result, dtype, config, device_str, tolerance_map=None
):
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
return infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
else:
return infinicore.strided_from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
list(torch_tensor.stride()),
dtype=to_infinicore_dtype(torch_tensor.dtype),
device=infini_device,
)
def convert_infinicore_to_torch(infini_result, torch_reference):
"""
Compare infinicore result with PyTorch reference result
Convert infinicore tensor to PyTorch tensor for comparison
Args:
infini_result: infinicore tensor result
torch_result: PyTorch tensor reference result
torch_reference: PyTorch tensor reference (for shape and device)
dtype: infinicore data type
config: test config
device_str: torch device string
device: device enum
tolerance_map: optional tolerance map (defaults to config's tolerance_map)
Returns:
bool: True if results match within tolerance
torch.Tensor: PyTorch tensor with infinicore data
"""
# Convert infinicore result to PyTorch tensor for comparison
torch_result_from_infini = torch.zeros(
torch_result.shape, dtype=to_torch_dtype(dtype), device=device_str
torch_reference.shape,
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
temp_tensor = create_infinicore_tensor(torch_result_from_infini, device_str)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
# Retrieve tolerance - use provided map or config's map
if tolerance_map is None:
tolerance_map = config.tolerance_map
atol, rtol = get_tolerance(tolerance_map, dtype)
def compare_results(
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False
):
"""
Generic function to compare infinicore result with PyTorch reference result
Args:
infini_result: infinicore tensor result
torch_result: PyTorch tensor reference result
atol: absolute tolerance
rtol: relative tolerance
debug_mode: whether to enable debug output
Returns:
bool: True if results match within tolerance
"""
# Convert infinicore result to PyTorch tensor for comparison
torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result)
# Debug mode: detailed comparison
if config.debug:
if debug_mode:
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
# Check if results match within tolerance
return torch.allclose(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
"""
Create a test-specific comparison function that handles test configuration
Args:
config: test configuration
dtype: infinicore data type
tolerance_map: optional tolerance map (defaults to config's tolerance_map)
mode_name: operation mode name for debug output
Returns:
callable: function that takes (infini_result, torch_result) and returns bool
"""
if tolerance_map is None:
tolerance_map = config.tolerance_map
atol, rtol = get_tolerance(tolerance_map, dtype)
def compare_test_results(infini_result, torch_result):
if config.debug and mode_name:
print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m")
return compare_results(
infini_result, torch_result, atol=atol, rtol=rtol, debug_mode=config.debug
)
return compare_test_results
def rearrange_tensor(tensor, new_strides):
"""
Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.
......
import torch
import infinicore
import sys
import os
# Framework path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from framework import (
TestConfig,
TestRunner,
TestCase,
create_infinicore_tensor,
compare_results,
get_args,
get_test_devices,
profile_operation,
to_torch_dtype,
InfiniDeviceNames,
torch_device_map,
)
# ==============================================================================
# Test Setup
# ==============================================================================
# Test cases
_TEST_CASES = [
# (a_shape, b_shape, result_shape, a_stride, b_stride, c_stride)
TestCase((2, 3), (3, 4), (2, 4), None, None, None),
TestCase((128, 256), (256, 64), (128, 64), None, None, None),
TestCase((2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
TestCase((1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
TestCase((6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
TestCase((4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
]
# Data types - now using infinicore native types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
# ==============================================================================
# Test Method
# ==============================================================================
def test_matmul(device, test_case, dtype, config):
"""
Test matmul operation
Args:
device: device enum
test_case: test case
dtype: infinicore data type
config: test config
"""
a_shape, b_shape, result_shape, a_stride, b_stride, c_stride = test_case.args
print(
f"Testing Matmul on {InfiniDeviceNames[device]} with "
f"a_shape:{a_shape}, b_shape:{b_shape}, result_shape:{result_shape}, "
f"a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, "
f"dtype:{dtype}"
)
# Create PyTorch tensors
device_str = torch_device_map[device]
torch_dtype = to_torch_dtype(dtype)
torch_a = torch.rand(a_shape, dtype=torch_dtype, device=device_str)
torch_b = torch.rand(b_shape, dtype=torch_dtype, device=device_str)
# Calculate PyTorch reference result
def torch_matmul():
return torch.matmul(torch_a, torch_b)
torch_result = torch_matmul()
# Create infinicore tensors
infini_a = create_infinicore_tensor(torch_a, device_str)
infini_b = create_infinicore_tensor(torch_b, device_str)
# Out-of-place matmul
def infini_matmul():
return infinicore.matmul(infini_a, infini_b)
infini_result = infini_matmul()
# Validate results using common method
is_valid = compare_results(infini_result, torch_result, dtype, config, device_str)
assert is_valid, "Matmul test failed"
# Performance test
if config.bench:
profile_operation(
"PyTorch",
torch_matmul,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
"Infinicore",
infini_matmul,
device_str,
config.num_prerun,
config.num_iterations,
)
def test_matmul_inplace(device, test_case, dtype, config):
"""
Test in-place matmul operation
Args:
device: device enum
test_case: test case
dtype: infinicore data type
config: test config
"""
a_shape, b_shape, result_shape, a_stride, b_stride, c_stride = test_case.args
print(
f"Testing In-place Matmul on {InfiniDeviceNames[device]} with "
f"a_shape:{a_shape}, b_shape:{b_shape}, result_shape:{result_shape}, "
f"dtype:{dtype}"
)
device_str = torch_device_map[device]
torch_dtype = to_torch_dtype(dtype)
# Create PyTorch tensors
torch_a = torch.rand(a_shape, dtype=torch_dtype, device=device_str)
torch_b = torch.rand(b_shape, dtype=torch_dtype, device=device_str)
# Create pre-allocated result tensor
torch_preallocated = torch.zeros(result_shape, dtype=torch_dtype, device=device_str)
# Calculate PyTorch reference result using in-place operation
def torch_matmul_inplace():
torch.matmul(torch_a, torch_b, out=torch_preallocated)
# Execute in-place operation
torch_matmul_inplace()
# Create infinicore tensors
infini_a = create_infinicore_tensor(torch_a, device_str)
infini_b = create_infinicore_tensor(torch_b, device_str)
infini_c = infinicore.empty(
result_shape, dtype=dtype, device=infinicore.device(device_str, 0)
)
# Test in-place matmul
def infini_matmul_inplace():
infinicore.matmul(infini_a, infini_b, out=infini_c)
# Execute in-place operation
infini_matmul_inplace()
# Validate results using common method
is_valid = compare_results(infini_c, torch_preallocated, dtype, config, device_str)
assert is_valid, "In-place matmul test failed"
# Performance test
if config.bench:
profile_operation(
"PyTorch In-place",
torch_matmul_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
"Infinicore In-place",
infini_matmul_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
# ==============================================================================
# Main Execution Function
# ==============================================================================
def main():
args = get_args()
# Create test configuration
config = TestConfig(
tensor_dtypes=_TENSOR_DTYPES,
tolerance_map=_TOLERANCE_MAP,
debug=args.debug,
bench=args.bench,
num_prerun=args.num_prerun,
num_iterations=args.num_iterations,
)
# Create test runner
runner = TestRunner(_TEST_CASES, config)
# Get test devices
devices = get_test_devices(args)
print("Starting matmul tests...")
all_passed = True
# Run out-of-place tests
print("\n--- Testing Out-of-place Matmul ---")
out_of_place_passed = runner.run_tests(devices, test_matmul)
all_passed = all_passed and out_of_place_passed
# Run in-place tests
print("\n--- Testing In-place Matmul ---")
in_place_passed = runner.run_tests(devices, test_matmul_inplace)
all_passed = all_passed and in_place_passed
runner.print_summary()
sys.exit(0 if all_passed else 1)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment