Unverified Commit f5e6d729 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Issue/497 - Enhanced Test Framework (#520)

* issue/497 - add dtype __eq__ and __hash__

* issue/497 - simplified infinicore test functions

* issue/497 - improved test framework

greatly reduced the code required for specific operators;
added strided tensor support;

* issue/497 - add add interface to assist test

* issue/497 - generalized test framework based on add

* issue/497 - support non-contiguous tensors in result comparison

* issue/497 - temporarily fixed strided tensor creation

* issue/497 - rms norm interface

* issue/497 - now requires test function definition

* issue/497 - support mixed dtype

* issue/497 - initial rms norm test

* issue/497 - unified in place and out of place tests

* issue/497 - renamed src/infinicore/op

* issue/497 - reduced comments

* issue/497 - attention

* issue/497 - removed generic parameter mapping

* issue/497 - temporary attention test

* issue/497 - captitalize op name initial

* issue/497 - add a script to run all op tests

* issue/497 - fix comments

* issue/497 - simplified infinicore tensor creation from torch

* issue/497 - support tensor init modes

* issue/497 - support tensor from/to files

* issue/497 - adjust naming
parent 37c76a90
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, shape, a_strides, b_strides, c_strides)
_TEST_CASES_DATA = [
(TestCase.BOTH, (13, 4), None, None, None),
(TestCase.BOTH, (13, 4), (10, 1), (10, 1), (10, 1)),
(TestCase.BOTH, (13, 4), (0, 1), None, None),
(TestCase.BOTH, (13, 4, 4), None, None, None),
(TestCase.BOTH, (13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
(TestCase.BOTH, (13, 4, 4), (4, 0, 1), (0, 4, 1), None),
(TestCase.BOTH, (16, 5632), None, None, None),
(TestCase.BOTH, (16, 5632), (13312, 1), (13312, 1), (13312, 1)),
]
def parse_test_cases(data):
"""
Parse add test case data according to format:
(operation_mode, shape, a_strides, b_strides, c_strides)
"""
operation_mode = data[0]
shape = data[1]
a_strides = data[2] if len(data) > 2 else None
b_strides = data[3] if len(data) > 3 else None
c_strides = data[4] if len(data) > 4 else None
# Create input specifications
inputs = []
# Input tensor a
if a_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, a_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Input tensor b (same shape as a)
if b_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, b_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Output tensor
if c_strides is not None:
output = TensorSpec.from_strided_tensor(shape, c_strides)
else:
output = TensorSpec.from_tensor(shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
class OpTest(BaseOperatorTest):
"""Add test with simplified test case parsing"""
def __init__(self):
super().__init__("Add")
def get_test_cases(self):
return _TEST_CASES
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
def torch_operator(self, a, b, out=None, **kwargs):
return torch.add(a, b, out=out)
def infinicore_operator(self, a, b, out=None, **kwargs):
return infinicore.add(a, b, out=out)
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
"""
This is for framework validation
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, n_q_head, n_kv_head, seq_len, head_dim, pos,
# k_cache_buf_len, v_cache_buf_len, q_strides, k_strides, v_strides,
# k_cache_strides, v_cache_strides)
_TEST_CASES_DATA = [
# Prefill stage
(
TestCase.OUT_OF_PLACE,
32,
4,
5,
64,
0,
2048,
2048,
[64, 2560, 1],
[64, 2560, 1],
[64, 2560, 1],
[64, 11264, 1],
[64, 11264, 1],
),
# Decode stage
(
TestCase.OUT_OF_PLACE,
32,
4,
1,
64,
3,
2048,
2048,
[64, 2560, 1],
[64, 2560, 1],
[64, 2560, 1],
[64, 11264, 1],
[64, 11264, 1],
),
# Small test case
(TestCase.OUT_OF_PLACE, 8, 4, 2, 16, 1, 8, 8, None, None, None, None, None),
# Another prefill case
(
TestCase.OUT_OF_PLACE,
28,
28,
15,
128,
0,
2048,
2048,
[128, 10752, 1],
[128, 10752, 1],
[128, 10752, 1],
[128, 3584, 1],
[128, 3584, 1],
),
]
# Epsilon constant for causal softmax
_EPSILON = 1e-5
def causal_softmax(x):
"""Apply causal mask and softmax to attention scores"""
input_dtype = x.dtype
# Create causal mask
mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
# Apply mask: set masked positions to -inf
masked = torch.where(mask == 1, -torch.inf, x.to(torch.float32))
# Apply softmax and convert back to original dtype
return torch.nn.functional.softmax(masked, dim=-1).to(input_dtype)
def torch_attention(q, k, v, k_cache, v_cache, pos):
"""PyTorch reference implementation of attention"""
input_dtype = q.dtype
n_q_head = q.shape[0]
n_kv_head = k.shape[0]
# Concatenate key and value caches
k_cache = k_cache[:, :pos, :] # (n_kv_head, pos, head_dim)
v_cache = v_cache[:, :pos, :] # (n_kv_head, pos, head_dim)
k = torch.cat([k_cache, k], dim=1) # (n_kv_head, total_seq_len, head_dim)
v = torch.cat([v_cache, v], dim=1) # (n_kv_head, total_seq_len, head_dim)
total_seq_len = k.shape[1]
head_dim = v.shape[-1]
# Handle grouped query attention (GQA)
if n_q_head != n_kv_head:
q = q.reshape(
n_kv_head, -1, head_dim
) # (n_kv_head, n_group * seq_len, head_dim)
# Scaled dot-product attention
attn_scores = (
torch.einsum("hqd,hkd->hqk", q.to(torch.float32), k.to(torch.float32))
.to(input_dtype)
.reshape(n_q_head, -1, total_seq_len)
) # (n_q_head, seq_len, total_seq_len)
# Scale by sqrt(head_dim)
attn_scores = attn_scores / (head_dim**0.5)
# Apply causal softmax
attn_weights = causal_softmax(attn_scores).reshape(
n_kv_head, -1, total_seq_len
) # (n_kv_head, seq_len, total_seq_len)
# Weighted sum of values
attn_output = (
torch.einsum(
"hqk,hkd->hqd", attn_weights.to(torch.float32), v.to(torch.float32)
)
.to(input_dtype)
.reshape(n_q_head, -1, head_dim)
.permute(1, 0, 2)
) # (seq_len, n_q_head, head_dim)
return attn_output
def parse_test_cases(data):
"""
Parse attention test case data according to format:
(operation_mode, n_q_head, n_kv_head, seq_len, head_dim, pos,
k_cache_buf_len, v_cache_buf_len, q_strides, k_strides, v_strides,
k_cache_strides, v_cache_strides)
"""
operation_mode = data[0]
n_q_head, n_kv_head, seq_len, head_dim, pos = (
data[1],
data[2],
data[3],
data[4],
data[5],
)
k_cache_buf_len, v_cache_buf_len = data[6], data[7]
q_strides = data[8] if len(data) > 8 else None
k_strides = data[9] if len(data) > 9 else None
v_strides = data[10] if len(data) > 10 else None
k_cache_strides = data[11] if len(data) > 11 else None
v_cache_strides = data[12] if len(data) > 12 else None
# Create input specifications
inputs = []
# Query tensor: (n_q_head, seq_len, head_dim)
if q_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor((n_q_head, seq_len, head_dim), q_strides)
)
else:
inputs.append(TensorSpec.from_tensor((n_q_head, seq_len, head_dim)))
# Key tensor: (n_kv_head, seq_len, head_dim)
if k_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor((n_kv_head, seq_len, head_dim), k_strides)
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, seq_len, head_dim)))
# Value tensor: (n_kv_head, seq_len, head_dim)
if v_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor((n_kv_head, seq_len, head_dim), v_strides)
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, seq_len, head_dim)))
# Key cache: (n_kv_head, k_cache_buf_len, head_dim)
if k_cache_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor(
(n_kv_head, k_cache_buf_len, head_dim), k_cache_strides
)
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, k_cache_buf_len, head_dim)))
# Value cache: (n_kv_head, v_cache_buf_len, head_dim)
if v_cache_strides is not None:
inputs.append(
TensorSpec.from_strided_tensor(
(n_kv_head, v_cache_buf_len, head_dim), v_cache_strides
)
)
else:
inputs.append(TensorSpec.from_tensor((n_kv_head, v_cache_buf_len, head_dim)))
# Position (scalar)
inputs.append(TensorSpec.from_scalar(pos))
# Output tensor: (seq_len, n_q_head, head_dim)
output_shape = (seq_len, n_q_head, head_dim)
output = TensorSpec.from_tensor(output_shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-5, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 1e-3, "rtol": 5e-2},
}
class OpTest(BaseOperatorTest):
"""Attention test with simplified test case parsing"""
def __init__(self):
super().__init__("Attention")
def get_test_cases(self):
return _TEST_CASES
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
def torch_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs):
result = torch_attention(q, k, v, k_cache, v_cache, pos)
if out is not None:
out.set_(result)
return out
else:
return result
def infinicore_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs):
return infinicore.attention(q, k, v, k_cache, v_cache, pos, out=out)
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
# If nbatch is None: a_shape=(m, k), b_shape=(k, n), c_shape=(m, n)
# If nbatch is provided: a_shape=(nbatch, m, k), b_shape=(nbatch, k, n), c_shape=(nbatch, m, n)
_TEST_CASES_DATA = [
# Basic 2D matmul
(TestCase.BOTH, None, 2, 4, 3, None, None, None),
(TestCase.BOTH, None, 128, 64, 256, None, None, None),
# Batched matmul
(TestCase.BOTH, 2, 4, 2048, 2048, None, None, None),
(TestCase.BOTH, 4, 48, 6, 64, None, None, None),
# Strided tensors
(TestCase.BOTH, None, 1, 2048, 2048, (4096, 1), (4096, 1), (4096, 1)),
(TestCase.BOTH, None, 6, 2560, 2048, (2048, 1), (1, 2048), (2560, 1)),
# Mixed cases
(TestCase.BOTH, 8, 16, 32, 16, None, None, None),
]
def parse_test_cases(data):
"""
Parse matmul test case data according to format:
(operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
"""
operation_mode = data[0]
nbatch = data[1]
m, n, k = data[2], data[3], data[4]
a_strides = data[5] if len(data) > 5 else None
b_strides = data[6] if len(data) > 6 else None
c_strides = data[7] if len(data) > 7 else None
# Determine shapes based on batch dimension
if nbatch is None:
a_shape = (m, k)
b_shape = (k, n)
c_shape = (m, n)
else:
a_shape = (nbatch, m, k)
b_shape = (nbatch, k, n)
c_shape = (nbatch, m, n)
# Create input specifications
inputs = []
# Tensor a
if a_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(a_shape, a_strides))
else:
inputs.append(TensorSpec.from_tensor(a_shape))
# Tensor b
if b_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(b_shape, b_strides))
else:
inputs.append(TensorSpec.from_tensor(b_shape))
# Output tensor
if c_strides is not None:
output = TensorSpec.from_strided_tensor(c_shape, c_strides)
else:
output = TensorSpec.from_tensor(c_shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
class OpTest(BaseOperatorTest):
"""Matmul test with simplified test case parsing"""
def __init__(self):
super().__init__("Matmul")
def get_test_cases(self):
return _TEST_CASES
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
def torch_operator(self, a, b, out=None, **kwargs):
return torch.matmul(a, b, out=out)
def infinicore_operator(self, a, b, out=None, **kwargs):
return infinicore.matmul(a, b, out=out)
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, y_shape, x_shape, w_shape, y_strides, x_strides)
_TEST_CASES_DATA = [
(TestCase.BOTH, (1, 4), (1, 4), (4,), None, None),
(TestCase.BOTH, (2, 4), (2, 4), (4,), None, None),
(TestCase.BOTH, (2, 2, 4), (2, 2, 4), (4,), None, None),
(TestCase.BOTH, (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)),
(TestCase.BOTH, (16, 2048), (16, 2048), (2048,), None, None),
(TestCase.BOTH, (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
]
def parse_test_cases(data):
"""
Parse RMSNorm test case data according to format:
(operation_mode, y_shape, x_shape, w_shape, y_strides, x_strides)
"""
operation_mode = data[0]
y_shape = data[1] # Output shape
x_shape = data[2] # Input shape
w_shape = data[3] # Weight shape (1D)
y_strides = data[4] if len(data) > 4 else None
x_strides = data[5] if len(data) > 5 else None
# Create input specifications
inputs = []
# Input tensor x
if x_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(x_shape, x_strides))
else:
inputs.append(TensorSpec.from_tensor(x_shape))
# Weight tensor (1D, always contiguous)
inputs.append(TensorSpec.from_tensor(w_shape))
# Output tensor
if y_strides is not None:
output = TensorSpec.from_strided_tensor(y_shape, y_strides)
else:
output = TensorSpec.from_tensor(y_shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types for individual tensors
_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16]
_WEIGHT_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Generate all dtype combinations
_DTYPE_COMBINATIONS = []
for input_dtype in _INPUT_DTYPES:
for weight_dtype in _WEIGHT_DTYPES:
_DTYPE_COMBINATIONS.append(
{
"input_0": input_dtype, # x tensor
"input_1": weight_dtype, # weight tensor
"output": input_dtype, # output tensor (same as input)
}
)
# Base data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16]
# Tolerance
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 2e-3, "rtol": 2e-3},
infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
}
# EPSILON constant for RMSNorm
_EPSILON = 1e-5
class OpTest(BaseOperatorTest):
"""RMSNorm test with simplified test case parsing"""
def __init__(self):
super().__init__("RMS_Norm")
def get_test_cases(self):
return _TEST_CASES
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
def get_dtype_combinations(self):
return _DTYPE_COMBINATIONS
def torch_operator(self, x, weight, out=None, **kwargs):
input_dtype = x.dtype
hidden_states = x.to(torch.float32)
scale = hidden_states.pow(2).mean(-1, keepdim=True).add_(_EPSILON).rsqrt_()
result = (hidden_states * scale * weight).to(input_dtype)
if out is not None:
out.set_(result)
return out
else:
return result
def infinicore_operator(self, x, weight, out=None, **kwargs):
return infinicore.rms_norm(x, weight, _EPSILON, out=out)
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
import os
import sys
import subprocess
import argparse
from pathlib import Path
def find_ops_directory(start_dir=None):
"""
Find the ops directory by searching from start_dir upwards.
"""
if start_dir is None:
start_dir = Path(__file__).parent
ops_dir = start_dir / "ops"
if ops_dir.exists() and (ops_dir / "rms_norm.py").exists():
return ops_dir
def run_all_op_tests(ops_dir=None, verbose=False, specific_ops=None, extra_args=None):
"""
Run all operator test scripts in the ops directory.
Args:
ops_dir (str, optional): Path to the ops directory. If None, uses the current directory.
verbose (bool): Whether to print detailed output.
specific_ops (list, optional): List of specific operator names to test (e.g., ['add', 'matmul']).
extra_args (list, optional): Extra command line arguments to pass to test scripts.
Returns:
dict: Results dictionary with test names as keys and (success, return_code, output) as values.
"""
if ops_dir is None:
ops_dir = find_ops_directory()
else:
ops_dir = Path(ops_dir)
if not ops_dir.exists():
print(f"Error: Ops directory '{ops_dir}' does not exist.")
return {}
print(f"Looking for test files in: {ops_dir}")
# Find all Python test files (looking for actual operator test files)
test_files = list(ops_dir.glob("*.py"))
# Filter out this script itself and non-operator test files
current_script = Path(__file__).name
test_files = [f for f in test_files if f.name != current_script]
# Further filter to include only files that look like operator tests
# (they typically import infinicore and BaseOperatorTest)
operator_test_files = []
for test_file in test_files:
try:
with open(test_file, "r", encoding="utf-8") as f:
content = f.read()
if "infinicore" in content and "BaseOperatorTest" in content:
operator_test_files.append(test_file)
elif verbose:
print(f" Skipping {test_file.name}: not an operator test file")
except Exception as e:
if verbose:
print(f" Could not read {test_file.name}: {e}")
continue
if specific_ops:
# Filter for specific operators (case insensitive)
filtered_files = []
for test_file in operator_test_files:
test_name = test_file.stem.lower()
if any(op.lower() in test_name for op in specific_ops):
filtered_files.append(test_file)
elif verbose:
print(f" Filtered out {test_file.name}: not in specific_ops list")
operator_test_files = filtered_files
if not operator_test_files:
print(f"No operator test files found in {ops_dir}")
print(f"Available Python files: {[f.name for f in test_files]}")
print(f"Current directory: {Path.cwd()}")
return {}
print(f"Found {len(operator_test_files)} operator test files:")
for test_file in operator_test_files:
print(f" - {test_file.name}")
results = {}
for test_file in operator_test_files:
test_name = test_file.stem
try:
# Run the test script
cmd = [sys.executable, str(test_file)]
# Add extra arguments if provided
if extra_args:
cmd.extend(extra_args)
if verbose:
print(f"Command: {' '.join(cmd)}")
print(f"Working directory: {ops_dir}")
# Always capture output to display it
result = subprocess.run(cmd, cwd=ops_dir, capture_output=True, text=True)
success = result.returncode == 0
results[test_name] = (
success,
result.returncode,
result.stdout,
result.stderr,
)
# Print the output from the test script
if result.stdout:
print(result.stdout)
if result.stderr:
print("STDERR:")
print(result.stderr)
if success:
print(f"✅ {test_name}: PASSED (return code: {result.returncode})")
else:
print(f"❌ {test_name}: FAILED (return code: {result.returncode})")
except Exception as e:
print(f"❌ {test_name}: ERROR - {str(e)}")
results[test_name] = (False, -1, "", str(e))
return results
def print_summary(results):
"""Print a summary of test results."""
print(f"\n{'='*80}")
print("TEST SUMMARY")
print(f"{'='*80}")
if not results:
print("No tests were run.")
return
passed = sum(1 for success, _, _, _ in results.values() if success)
total = len(results)
print(f"Total tests: {total}")
print(f"Passed: {passed}")
print(f"Failed: {total - passed}")
if total > 0:
print(f"Success rate: {passed/total*100:.1f}%")
if passed == total:
print("\n🎉 All tests passed!")
else:
print("\nFailed tests:")
for test_name, (success, returncode, stdout, stderr) in results.items():
if not success:
print(f" - {test_name} (return code: {returncode})")
# Print brief error info for failed tests
if stderr:
error_lines = stderr.strip().split("\n")
if error_lines:
print(f" Error: {error_lines[0]}")
def main():
"""Main entry point with command line argument parsing."""
parser = argparse.ArgumentParser(
description="Run all operator tests in the ops directory", add_help=False
)
# Our script's specific arguments
parser.add_argument(
"--ops-dir", type=str, help="Path to the ops directory (default: auto-detect)"
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Print detailed command information for each test",
)
parser.add_argument(
"--ops", nargs="+", help="Run specific operators only (e.g., --ops add matmul)"
)
parser.add_argument(
"--list",
action="store_true",
help="List all available test files without running them",
)
parser.add_argument(
"-h", "--help", action="store_true", help="Show this help message and exit"
)
# Parse known args first, leave the rest for the test scripts
args, unknown_args = parser.parse_known_args()
if args.help:
parser.print_help()
print("\nExtra arguments that will be passed to test scripts:")
print(" --nvidia, --cpu, --bench, --debug, etc.")
return
# Auto-detect ops directory if not provided
if args.ops_dir is None:
ops_dir = find_ops_directory()
else:
ops_dir = Path(args.ops_dir)
if args.list:
# Just list available test files
test_files = list(ops_dir.glob("*.py"))
current_script = Path(__file__).name
test_files = [f for f in test_files if f.name != current_script]
operator_test_files = []
for test_file in test_files:
try:
with open(test_file, "r", encoding="utf-8") as f:
content = f.read()
if "infinicore" in content and "BaseOperatorTest" in content:
operator_test_files.append(test_file)
except:
continue
if operator_test_files:
print(f"Available operator test files in {ops_dir}:")
for test_file in operator_test_files:
print(f" - {test_file.name}")
else:
print(f"No operator test files found in {ops_dir}")
print(f"Available Python files: {[f.name for f in test_files]}")
return
# Show what extra arguments will be passed
if unknown_args:
print(f"Passing extra arguments to test scripts: {unknown_args}")
# Run all tests
results = run_all_op_tests(
ops_dir=ops_dir,
verbose=args.verbose,
specific_ops=args.ops,
extra_args=unknown_args,
)
print_summary(results)
# Exit with appropriate code
if results and all(success for success, _, _, _ in results.values()):
sys.exit(0)
else:
sys.exit(1)
if __name__ == "__main__":
main()
......@@ -345,7 +345,7 @@ target("_infinicore")
add_files("src/infinicore/context/*.cc")
add_files("src/infinicore/context/*/*.cc")
add_files("src/infinicore/tensor/*.cc")
add_files("src/infinicore/op/*/*.cc")
add_files("src/infinicore/ops/*/*.cc")
add_files("src/infinicore/pybind11/**.cc")
set_installdir("python/infinicore")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment