import torch import ctypes from ctypes import c_uint64 from libinfiniop import ( LIBINFINIOP, TestTensor, get_test_devices, check_error, test_operator, get_args, debug, get_tolerance, profile_operation, TestWorkspace, InfiniDtype, InfiniDtypeNames, InfiniDeviceNames, infiniopOperatorDescriptor_t, ) from enum import Enum, auto # ============================================================================== # Configuration (Internal Use Only) # ============================================================================== # These are not meant to be imported from other modules _TEST_CASES_ = [ # x_shape, w_shape, y_shape, alpha, beta ((128, 512), (512, 1024), (128, 1024)), ((256, 1024), (1024, 2048), (256, 2048)), ((1024, 2048), (2048, 1024), (1024, 1024)), ] class Inplace(Enum): OUT_OF_PLACE = auto() INPLACE = auto() # Inplace options applied for each test case in _TEST_CASES_ _INPLACE = [ Inplace.INPLACE, ] _TEST_CASES = [ test_case + (inplace_item,) for test_case in _TEST_CASES_ for inplace_item in _INPLACE ] # Data types used for testing _TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] # Tolerance map for different data types _TOLERANCE_MAP = { InfiniDtype.F16: {"atol": 3e-1, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 3e-1, "rtol": 1e-2}, } DEBUG = False PROFILE = False NUM_PRERUN = 10 NUM_ITERATIONS = 100 def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) if bias is not None: o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + bias else: o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) return o.to(out_dtype) def test( handle, device, x_shape, w_shape, y_shape, inplace=Inplace.OUT_OF_PLACE, dtype=InfiniDtype.BF16, sync=None, ): print( f"Testing scaled_mm_int8 on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}" ) M, K = x_shape N = w_shape[1] # --- Tensor Descriptor --- # orig: create a random int8 tensor as the reference data source # torch: extract the torch view to adjust layout/stride # final: wrap it back as TestTensor with explicit stride for device execution x_packed_orig = TestTensor( (M, K), None, InfiniDtype.I8, device, mode="randint", randint_low=-128, randint_high=127, ) x_packed_torch = x_packed_orig.torch_tensor() x_packed = TestTensor( (M, K), x_packed_torch.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=x_packed_torch, ) weights_orig = TestTensor( (N, K), None, InfiniDtype.I8, device, mode="randint", randint_low=-128, randint_high=127, ) weights_torch = weights_orig.torch_tensor().t() weights = TestTensor( (K, N), weights_torch.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights_torch, ) x_scale_orig = TestTensor((M,), None, InfiniDtype.F32, device, mode="random") x_scale_torch = x_scale_orig.torch_tensor() x_scale = TestTensor( (M,), x_scale_torch.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=x_scale_torch, ) weights_scale_orig = TestTensor((N,), None, InfiniDtype.F32, device, mode="random") weights_scale_torch = weights_scale_orig.torch_tensor() weights_scale = TestTensor( (N,), weights_scale_torch.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale_torch, ) bias_orig = TestTensor((N,), None, dtype, device, mode="random") bias_torch = bias_orig.torch_tensor() bias = TestTensor( (N,), bias_torch.stride(), dtype, device, mode="manual", set_tensor=bias_torch ) y = TestTensor(y_shape, None, dtype, device, mode="zeros") ans = torch_scaled_mm( x_packed.torch_tensor(), weights.torch_tensor(), x_scale.torch_tensor(), weights_scale.torch_tensor(), out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias.torch_tensor(), ) descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreateI8GemmDescriptor( handle, ctypes.byref(descriptor), y.descriptor, bias.descriptor, x_packed.descriptor, x_scale.descriptor, weights.descriptor, weights_scale.descriptor, ) ) workspace_size = c_uint64(0) check_error( LIBINFINIOP.infiniopGetI8GemmWorkspaceSize( descriptor, ctypes.byref(workspace_size) ) ) workspace = TestWorkspace(workspace_size.value, x_packed.device) def lib_linear(): check_error( LIBINFINIOP.infiniopI8Gemm( descriptor, workspace.data(), workspace_size.value, y.data(), bias.data(), x_packed.data(), x_scale.data(), weights.data(), weights_scale.data(), None, ) ) lib_linear() if sync is not None: sync() atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: debug(y.actual_tensor(), ans, atol=atol, rtol=rtol) assert torch.allclose(y.actual_tensor(), ans, atol=atol, rtol=rtol) # Profiling workflow if PROFILE: # fmt: off profile_operation( "PyTorch", lambda: torch_scaled_mm( x_packed.torch_tensor(), weights.torch_tensor(), x_scale.torch_tensor(), weights_scale.torch_tensor(), out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias.torch_tensor() ), device, NUM_PRERUN, NUM_ITERATIONS ) profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS) # fmt: on check_error(LIBINFINIOP.infiniopDestroyI8GemmDescriptor(descriptor)) if __name__ == "__main__": args = get_args() # Configure testing options DEBUG = args.debug PROFILE = args.profile NUM_PRERUN = args.num_prerun NUM_ITERATIONS = args.num_iterations for device in get_test_devices(args): # muDNN(v3101): INT8 quantized multiplication → BF16 output. # Moore backend: BF16 output only. if args.moore == True: _TENSOR_DTYPES_MOORE = [InfiniDtype.BF16] test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES_MOORE) else: test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) print("\033[92mTest passed!\033[0m")