"""Operator test for ``hardtanh``; PyTorch provides the reference result."""

import os
import sys

# Make the sibling ``framework`` package importable when run as a script.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import infinicore
import torch
from framework import (
    BaseOperatorTest,
    TensorSpec,
    TestCase,
    GenericTestRunner,
    is_broadcast,
)

# Test cases format: (in_shape, in_strides_or_None, min_val_or_None, max_val_or_None)
# Trailing fields may be omitted; a missing or None bound falls back to the
# hardtanh default (_DEFAULT_MIN_VAL / _DEFAULT_MAX_VAL).
_TEST_CASES_DATA = [
    ((13, 4), None, -1.0, 1.0),
    ((13, 4), (10, 1), -0.5, 0.5),
    ((8, 8, 8), None, -2.0, 2.0),
]

# Defaults of hardtanh(input, min_val=-1.0, max_val=1.0, inplace=False).
_DEFAULT_MIN_VAL = -1.0
_DEFAULT_MAX_VAL = 1.0

_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
    infinicore.bfloat16: {"atol": 1e-2, "rtol": 5e-2},
}

# Used for any dtype missing from _TOLERANCE_MAP.
_DEFAULT_TOLERANCE = {"atol": 1e-5, "rtol": 1e-4}

_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]


def _optional_field(data, index, default):
    """Return ``data[index]``, or *default* when the field is absent or None."""
    if len(data) > index and data[index] is not None:
        return data[index]
    return default


def parse_test_cases():
    """Build the hardtanh test cases.

    Reference signature: ``hardtanh(input, min_val=-1.0, max_val=1.0, inplace=False)``.

    For every entry of ``_TEST_CASES_DATA`` and every dtype in
    ``_TENSOR_DTYPES`` this emits an out-of-place case. It also emits an
    in-place case (``inplace=True``) unless ``is_broadcast`` reports the
    input strides as a broadcast layout.

    Returns:
        list[TestCase]: the generated test cases.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        shape = data[0]
        in_strides = _optional_field(data, 1, None)
        min_val = _optional_field(data, 2, _DEFAULT_MIN_VAL)
        max_val = _optional_field(data, 3, _DEFAULT_MAX_VAL)

        # NOTE(review): in-place writes are skipped for broadcast layouts,
        # presumably because several logical elements alias one storage slot.
        input_supports_inplace = not is_broadcast(in_strides)

        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, _DEFAULT_TOLERANCE)
            input_spec = TensorSpec.from_tensor(shape, in_strides, dtype)

            test_cases.append(
                TestCase(
                    inputs=[input_spec],
                    kwargs={"min_val": min_val, "max_val": max_val},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description="Hardtanh - OUT_OF_PLACE",
                )
            )

            if input_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec],
                        kwargs={
                            "min_val": min_val,
                            "max_val": max_val,
                            "inplace": True,
                        },
                        output_spec=None,
                        # Presumably selects input 0 (the tensor mutated in
                        # place) as the value to compare — confirm in framework.
                        comparison_target=0,
                        tolerance=tolerance,
                        description="Hardtanh - INPLACE",
                    )
                )

    return test_cases


class OpTest(BaseOperatorTest):
    """Hardtanh operator test; ``torch.nn.functional.hardtanh`` is the reference."""

    def __init__(self):
        super().__init__("Hardtanh")

    def get_test_cases(self):
        """Return the cases produced by ``parse_test_cases``."""
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """Reference implementation: ``torch.nn.functional.hardtanh``."""
        return torch.nn.functional.hardtanh(*args, **kwargs)

    # TODO: enable once infinicore.nn.functional.hardtanh is available.
    # def infinicore_operator(self, *args, **kwargs):
    #     """InfiniCore implementation (operator not yet available)."""
    #     return infinicore.nn.functional.hardtanh(*args, **kwargs)


def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()