import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import infinicore import torch from framework import ( BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner, is_broadcast, ) # Test cases format: (shape, a_strides, b_strides) _TEST_CASES_DATA = [ ((6, 8), None, None), ((8, 4), (16, 1), None), ((5, 5), None, (10, 1)), ((3, 7), (14, 1), (14, 1)), ((10, 3), None, None), ((2, 16), (32, 1), (32, 1)), ] _TOLERANCE_MAP = { infinicore.float16: {"atol": 0, "rtol": 1e-2}, infinicore.float32: {"atol": 0, "rtol": 1e-3}, infinicore.bfloat16: {"atol": 0, "rtol": 5e-2}, } _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] def parse_test_cases(): test_cases = [] for data in _TEST_CASES_DATA: shape, a_strides, b_strides = data a_inplace = not is_broadcast(a_strides) b_inplace = not is_broadcast(b_strides) for dtype in _TENSOR_DTYPES: tol = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3}) a_spec = TensorSpec.from_tensor(shape, a_strides, dtype) b_spec = TensorSpec.from_tensor(shape, b_strides, dtype) test_cases.append( TestCase( inputs=[a_spec, b_spec], kwargs={}, output_spec=None, comparison_target=None, tolerance=tol, description="minimum - OUT_OF_PLACE", ) ) if a_inplace: test_cases.append( TestCase( inputs=[a_spec, b_spec], kwargs={"out": 0}, output_spec=None, comparison_target=0, tolerance=tol, description="minimum - INPLACE(a)", ) ) if b_inplace: test_cases.append( TestCase( inputs=[a_spec, b_spec], kwargs={"out": 1}, output_spec=None, comparison_target=1, tolerance=tol, description="minimum - INPLACE(b)", ) ) return test_cases class OpTest(BaseOperatorTest): """Minimum operator test with simplified implementation""" def __init__(self): super().__init__("Minimum") def get_test_cases(self): return parse_test_cases() def torch_operator(self, *args, **kwargs): return torch.minimum(*args, **kwargs) # def infinicore_operator(self, *args, **kwargs): # """InfiniCore implementation (operator not yet available).""" # return infinicore.minimum(*args, **kwargs) def main(): """Main entry point""" runner = GenericTestRunner(OpTest) runner.run_and_exit() if __name__ == "__main__": main()