import torch import ctypes from ctypes import c_uint64 from libinfiniop import ( LIBINFINIOP, TestTensor, get_test_devices, check_error, test_operator, get_args, debug, get_tolerance, TestWorkspace, InfiniDtype, InfiniDtypeNames, InfiniDeviceNames, infiniopOperatorDescriptor_t, ) from enum import Enum, auto # ============================================================================== # Configuration (Internal Use Only) # ============================================================================== # These are not meant to be imported from other modules _TEST_CASES_ = [ # x_shape = [M,K], w_shape = [N, K], sym, y_shape = [M, N] ((100, 3584), (10752, 3584), True, (100, 10752)), ((1000, 3584), (10752, 3584), True, (1000, 10752)), ((1, 3584), (10752, 3584), True, (1, 10752)), ((2000, 3584), (10752, 3584), True, (2000, 10752)), ] class Inplace(Enum): OUT_OF_PLACE = auto() INPLACE = auto() # Inplace options applied for each test case in _TEST_CASES_ _INPLACE = [ Inplace.INPLACE, ] _TEST_CASES = [ test_case + (inplace_item,) for test_case in _TEST_CASES_ for inplace_item in _INPLACE ] # Data types used for testing _TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] # Tolerance map for different data types _TOLERANCE_MAP = { InfiniDtype.F16: {"atol": 3e-1, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 3e-1, "rtol": 1e-2}, } DEBUG = False PROFILE = False NUM_PRERUN = 10 NUM_ITERATIONS = 1000 def mm(x, w, bias, out_dtype): return (torch.matmul(x, w + bias)).to(out_dtype) def scaled_mm(x, w_p, w_s, bias, out_dtype): return ( torch.matmul(x.to(torch.float32), w_p.to(torch.float32)) * w_s.view(1, -1) + bias ).to(out_dtype) def to_int8(tensor: torch.Tensor) -> torch.Tensor: return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) if bias is not None: o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + bias else: o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) return o.to(out_dtype) def per_token_quant_int8_torch(x): x = x.float() absmax = x.abs().max(dim=-1).values absmax = absmax.clamp_min(1e-10).unsqueeze(-1) scale_x = absmax / 127 x_q = x.mul(127 / absmax) x_q = torch.round(x_q).to(torch.int8) return x_q, scale_x def test( handle, device, x_shape, w_shape, symmetric, y_shape, inplace=Inplace.OUT_OF_PLACE, dtype=InfiniDtype.BF16, sync=None, ): print( f"Testing Linear on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, symmetric:{symmetric}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}" ) M, K = x_shape N = w_shape[0] x = TestTensor(x_shape, None, dtype, device) x_packed = TestTensor(x_shape, None, InfiniDtype.I8, device, mode="zeros") x_scale = TestTensor((M, 1), None, InfiniDtype.F32, device) dev = x.torch_tensor().device weights_packed = to_int8(torch.randn(w_shape, device=dev).t() * 5) weights_scale = torch.randn((N, 1), device=dev, dtype=torch.float32) bias = ( torch.randn( (N,), device=dev, dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, ) * 10 ) w_packed = TestTensor( (K, N), weights_packed.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights_packed, ) w_scale = TestTensor( (N, 1), weights_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale, ) weights = w_packed.torch_tensor() * w_scale.torch_tensor().view(1, -1) y = TestTensor(y_shape, None, dtype, device) bias = TestTensor( (N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias ) x_mm = x.torch_tensor().to( torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16 ) w_mm = weights.to(torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16) quant_descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreatePerChannelQuantI8Descriptor( handle, ctypes.byref(quant_descriptor), x_packed.descriptor, x_scale.descriptor, None, x.descriptor, ) ) quant_workspace_size = c_uint64(0) check_error( LIBINFINIOP.infiniopGetPerChannelQuantI8WorkspaceSize( quant_descriptor, ctypes.byref(quant_workspace_size) ) ) quant_workspace = TestWorkspace(quant_workspace_size.value, x.device) def lib_per_channel_quant_int8(): check_error( LIBINFINIOP.infiniopPerChannelQuantI8( quant_descriptor, quant_workspace.data(), quant_workspace_size.value, x_packed.data(), x_scale.data(), None, x.data(), None, ) ) scaled_mm_descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreateI8GemmDescriptor( handle, ctypes.byref(scaled_mm_descriptor), y.descriptor, bias.descriptor, x_packed.descriptor, x_scale.descriptor, w_packed.descriptor, w_scale.descriptor, ) ) scaled_mm_workspace_size = c_uint64(0) check_error( LIBINFINIOP.infiniopGetI8GemmWorkspaceSize( scaled_mm_descriptor, ctypes.byref(scaled_mm_workspace_size) ) ) scaled_mm_workspace = TestWorkspace(scaled_mm_workspace_size.value, x_packed.device) def lib_linear(): check_error( LIBINFINIOP.infiniopI8Gemm( scaled_mm_descriptor, scaled_mm_workspace.data(), scaled_mm_workspace_size.value, y.data(), bias.data(), x_packed.data(), x_scale.data(), w_packed.data(), w_scale.data(), None, ) ) def lib_w8a8int8_linearFunction(): lib_per_channel_quant_int8() lib_linear() def lib_torch_mm(): mm( x_mm, w_mm, bias.torch_tensor(), out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, ) x_p, x_s = per_token_quant_int8_torch(x.torch_tensor()) lib_w8a8int8_linearFunction() scaled_mm_torch = torch_scaled_mm( x_p, w_packed.torch_tensor(), x_s, w_scale.torch_tensor(), torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias.torch_tensor(), ) mm_torch = scaled_mm( x.torch_tensor(), w_packed.torch_tensor(), w_scale.torch_tensor(), bias.torch_tensor(), out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, ) if sync is not None: sync() atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: debug(y.actual_tensor(), mm_torch, atol=atol, rtol=rtol) # The quantization test did not normalize the test data, leading to large errors; the error check has been temporarily removed. def profile_operation(name, func, device, num_prerun, num_iterations): # Warm up for _ in range(num_prerun): func() torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() for _ in range(num_iterations): func() end.record() torch.cuda.synchronize() elapsed = start.elapsed_time(end) print( f"{name} took {elapsed / num_iterations:.6f} ms over {num_iterations} iterations" ) # Profiling workflow if PROFILE: profile_operation( "PyTorch mm ", lambda: lib_torch_mm(), device, NUM_PRERUN, NUM_ITERATIONS, ) profile_operation( "lib total ", lambda: lib_w8a8int8_linearFunction(), device, NUM_PRERUN, NUM_ITERATIONS, ) profile_operation( "lib quant ", lambda: lib_per_channel_quant_int8(), device, NUM_PRERUN, NUM_ITERATIONS, ) profile_operation( "lib scaled mm ", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS, ) check_error(LIBINFINIOP.infiniopDestroyI8GemmDescriptor(scaled_mm_descriptor)) check_error( LIBINFINIOP.infiniopDestroyPerChannelQuantI8Descriptor(quant_descriptor) ) if __name__ == "__main__": args = get_args() # Configure testing options DEBUG = args.debug PROFILE = args.profile NUM_PRERUN = args.num_prerun NUM_ITERATIONS = args.num_iterations for device in get_test_devices(args): test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) print("\033[92mTest passed!\033[0m")