Unverified commit 2790a7b2, authored by PanZezhong1725, committed by GitHub

Merge pull request #308 from InfiniTensor/issue/307

issue/307 unify test tensor creation in pytorch tests

parents 70146b74 f62e952e
from ctypes import c_int, Structure, POINTER
class TensorDescriptor(Structure):
_fields_ = []
infiniopTensorDescriptor_t = POINTER(TensorDescriptor)
class Handle(Structure):
_fields_ = [("device", c_int), ("device_id", c_int)]
infiniopHandle_t = POINTER(Handle)
class OpDescriptor(Structure):
_fields_ = [("device", c_int), ("device_id", c_int)]
infiniopOperatorDescriptor_t = POINTER(OpDescriptor)
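
# A hedged aside (illustration, not part of the commit): these are opaque
# out-parameter handles in the classic ctypes pattern -- the C side allocates
# the struct and fills a POINTER(...) passed via ctypes.byref(). A stand-in:
from ctypes import pointer as _pointer

_h = _pointer(Handle(0, 1))   # what infiniopCreateHandle would fill in
assert _h.contents.device_id == 1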
from typing import Sequence
import torch
import ctypes
from .datatypes import *
from .devices import *
from .liboperators import infiniopTensorDescriptor_t, LIBINFINIOP, infiniopHandle_t


def check_error(status):
...@@ -11,79 +11,173 @@ def check_error(status):
        raise Exception("Error code " + str(status))
class CTensor:
    def __init__(self, dt: InfiniDtype, shape, strides):
        self.descriptor = infiniopTensorDescriptor_t()
        self.dt = dt
        self.ndim = len(shape)
        if strides is None:
            strides = [1 for _ in shape]
            for i in range(self.ndim - 2, -1, -1):
                strides[i] = strides[i + 1] * shape[i + 1]
        assert self.ndim == len(strides)
        self.c_shape = (ctypes.c_size_t * self.ndim)(*shape)
        self.c_strides = (ctypes.c_ssize_t * self.ndim)(*strides)
        LIBINFINIOP.infiniopCreateTensorDescriptor(
            ctypes.byref(self.descriptor),
            self.ndim,
            self.c_shape,
            self.c_strides,
            self.dt,
        )

    def destroy_desc(self):
        if self.descriptor is not None:
            LIBINFINIOP.infiniopDestroyTensorDescriptor(self.descriptor)
            self.descriptor = None


class TestTensor(CTensor):
    def __init__(
        self,
        shape,
        strides,
        dt: InfiniDtype,
        device: InfiniDeviceEnum,
        mode="random",
        scale=None,
        bias=None,
        set_tensor=None,
    ):
        self.dt = dt
        self.device = device
        self.shape = shape
        self.strides = strides
        torch_shape = []
        torch_strides = [] if strides is not None else None
        for i in range(len(shape)):
            if strides is not None and strides[i] == 0:
                torch_shape.append(1)
                torch_strides.append(1)
            elif strides is not None and strides[i] != 0:
                torch_shape.append(shape[i])
                torch_strides.append(strides[i])
            else:
                torch_shape.append(shape[i])
        if mode == "random":
            self._torch_tensor = torch.rand(
                torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device]
            )
        elif mode == "zeros":
            self._torch_tensor = torch.zeros(
                torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device]
            )
        elif mode == "ones":
            self._torch_tensor = torch.ones(
                torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device]
            )
        elif mode == "manual":
            assert set_tensor is not None
            assert torch_shape == list(set_tensor.shape)
            assert torch_strides == list(set_tensor.stride())
            self._torch_tensor = set_tensor.to(to_torch_dtype(dt)).to(
                torch_device_map[device]
            )
        else:
            raise ValueError("Unsupported mode")
        if scale is not None:
            self._torch_tensor *= scale
        if bias is not None:
            self._torch_tensor += bias
        if strides is not None:
            self._data_tensor = rearrange_tensor(self._torch_tensor, torch_strides)
        else:
            self._data_tensor = self._torch_tensor.clone()
        super().__init__(self.dt, shape, strides)

    def torch_tensor(self):
        return self._torch_tensor

    def actual_tensor(self):
        return self._data_tensor

    def data(self):
        return self._data_tensor.data_ptr()

    def is_broadcast(self):
        return self.strides is not None and 0 in self.strides

    @staticmethod
    def from_torch(torch_tensor, dt: InfiniDtype, device: InfiniDeviceEnum):
        shape_ = list(torch_tensor.shape)
        strides_ = list(torch_tensor.stride())
        return TestTensor(
            shape_, strides_, dt, device, mode="manual", set_tensor=torch_tensor
        )
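
# A hedged aside (illustration, not from the diff): TestTensor encodes
# broadcasting with zero strides -- a stride of 0 collapses that dimension to
# size 1 in the backing torch tensor, exactly as the loop above does.
_shape, _strides = (4, 8), (0, 1)
_torch_shape = [1 if s == 0 else d for d, s in zip(_shape, _strides)]
assert _torch_shape == [1, 8]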
def to_torch_dtype(dt: InfiniDtype, compatability_mode=False):
if dt == InfiniDtype.I8:
return torch.int8
elif dt == InfiniDtype.I16:
return torch.int16
elif dt == InfiniDtype.I32:
return torch.int32
elif dt == InfiniDtype.I64:
return torch.int64
elif dt == InfiniDtype.U8:
return torch.uint8
elif dt == InfiniDtype.F16:
return torch.float16
elif dt == InfiniDtype.BF16:
return torch.bfloat16
elif dt == InfiniDtype.F32:
return torch.float32
elif dt == InfiniDtype.F64:
return torch.float64
    # TODO: The following types may not be supported by older
    # versions of PyTorch. Use compatability mode to convert them.
elif dt == InfiniDtype.U16:
return torch.int16 if compatability_mode else torch.uint16
elif dt == InfiniDtype.U32:
return torch.int32 if compatability_mode else torch.uint32
elif dt == InfiniDtype.U64:
return torch.int64 if compatability_mode else torch.uint64
else:
raise ValueError("Unsupported data type")
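
# A hedged aside (illustration, not from the diff): in compatability mode the
# unsigned dtypes fall back to same-width signed ones, so the bit pattern is
# preserved even though the torch dtype differs.
_bits = torch.tensor([65535], dtype=torch.int32).to(torch.int16)  # wraps to -1
assert int(_bits.to(torch.int32)[0]) & 0xFFFF == 65535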
class TestWorkspace:
    def __init__(self, size, device):
        if size != 0:
            self.tensor = TestTensor((size,), None, InfiniDtype.U8, device, mode="ones")
        else:
            self.tensor = None
        self._size = size

    def data(self):
        if self.tensor is not None:
            return self.tensor.data()
        else:
            return None

    def size(self):
        return ctypes.c_uint64(self._size)
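
# A hedged usage sketch (illustration, not part of the diff): the size normally
# comes from an infiniopGet*WorkspaceSize query; a zero-size workspace keeps
# tensor=None, so data() returns None while size() still yields a c_uint64.
_ws = TestWorkspace(0, device=None)
assert _ws.data() is None and _ws.size().value == 0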
def create_handle():
    handle = infiniopHandle_t()
    check_error(LIBINFINIOP.infiniopCreateHandle(ctypes.byref(handle)))
    return handle


def destroy_handle(handle):
    check_error(LIBINFINIOP.infiniopDestroyHandle(handle))


def rearrange_tensor(tensor, new_strides):
...@@ -132,13 +226,6 @@ def rearrange_tensor(tensor, new_strides):
    return new_tensor


def get_args():
    import argparse
...@@ -232,6 +319,7 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
        If True, the function will print detailed information about any discrepancies between the tensors.
    """
    import numpy as np

    # If BF16, convert everything to FP32 before comparing
    if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
        actual = actual.to(torch.float32)
...@@ -316,7 +404,9 @@ def debug_all(
    assert passed, "\033[31mThe condition has not been satisfied\033[0m"


def print_discrepancy(
    actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
):
    if actual.shape != expected.shape:
        raise ValueError("Tensors must have the same shape to compare.")
...@@ -329,8 +419,12 @@ def print_discrepancy(actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbo
    expected_isnan = torch.isnan(expected)

    # Calculate the difference mask based on atol and rtol
    nan_mismatch = (
        actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
    )
    diff_mask = nan_mismatch | (
        torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
    )
    diff_indices = torch.nonzero(diff_mask, as_tuple=False)

    delta = actual - expected
...@@ -427,35 +521,33 @@ def profile_operation(desc, func, torch_device, NUM_PRERUN, NUM_ITERATIONS):
    print(f"  {desc} time: {elapsed * 1000 :6f} ms")
def test_operator(device, test_func, test_cases, tensor_dtypes):
    """
    Testing a specified operator on the given device with the given test function, test cases, and tensor data types.

    Arguments:
    ----------
    - device (InfiniDeviceEnum): The device on which the operator should be tested. See device.py.
    - test_func (function): The test function to be executed for each test case.
    - test_cases (list of tuples): A list of test cases, where each test case is a tuple of parameters
      to be passed to `test_func`.
    - tensor_dtypes (list): A list of tensor data types (e.g., `torch.float32`) to test.
    """
    LIBINFINIOP.infinirtSetDevice(device, ctypes.c_int(0))
    handle = create_handle()
    tensor_dtypes = filter_tensor_dtypes_by_device(device, tensor_dtypes)
    try:
        for test_case in test_cases:
            for tensor_dtype in tensor_dtypes:
                test_func(
                    handle,
                    device,
                    *test_case,
                    tensor_dtype,
                    get_sync_func(device),
                )
    finally:
        destroy_handle(handle)
def get_test_devices(args):
...@@ -506,7 +598,7 @@ def get_test_devices(args):


def get_sync_func(device):
    import torch

    device_str = torch_device_map[device]
    if device == InfiniDeviceEnum.CPU:
        sync = None
...
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
import ctypes
import sys
import os
import time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
)
from operatorspy.tests.test_utils import get_args
import torch
from typing import Tuple
# Constants controlling whether to profile the PyTorch and lib functions.
# NOTE: a synchronization call must be added to the lib function manually,
# e.g., cudaDeviceSynchronize() for CUDA
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class MaxPoolDescriptor(Structure):
_fields_ = [("device", c_int32)]
infiniopMaxPoolDescriptor_t = POINTER(MaxPoolDescriptor)
def pool(x, k, padding, stride, dilation=1):
pooling_layers = {
1: torch.nn.MaxPool1d,
2: torch.nn.MaxPool2d,
3: torch.nn.MaxPool3d,
}
ndim = len(x.shape) - 2
if ndim not in pooling_layers:
        print("Error: PyTorch -> Unsupported tensor dimension")
return None
ans = pooling_layers[ndim](k, stride=stride, padding=padding, dilation=dilation)(x)
if PROFILE:
torch.cuda.synchronize()
return ans
def inferShape(x_shape, kernel_shape, padding, strides):
    assert (
        len(x_shape) - 2 == len(kernel_shape) == len(padding) == len(strides)
    ), "kernel, padding, and strides must have the same length, and input x must have two more dimensions than the kernel"
input_shape = x_shape[2:]
output_shape = []
for dim, k, p, s in zip(input_shape, kernel_shape, padding, strides):
output_dim = (dim + 2 * p - k) // s + 1
output_shape.append(output_dim)
return x_shape[:2] + tuple(output_shape)
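
# A quick numeric check of the formula above (values taken from the
# (32, 3, 224, 224) test case at the bottom of this file):
#   output_dim = (dim + 2 * padding - kernel) // stride + 1
_dim, _k, _p, _s = 224, 3, 1, 2
assert (_dim + 2 * _p - _k) // _s + 1 == 112  # -> output shape (32, 3, 112, 112)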
# Convert a Python tuple to a ctypes void pointer
def tuple_to_void_p(py_tuple: Tuple):
array = ctypes.c_int64 * len(py_tuple)
data_array = array(*py_tuple)
return ctypes.cast(data_array, ctypes.c_void_p)
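
# A hedged sanity check (illustration only): ctypes.cast keeps the source
# array alive through the result, and the pointer round-trips correctly.
_p = tuple_to_void_p((3, 1, 4))
_arr = ctypes.cast(_p, ctypes.POINTER(ctypes.c_int64))
assert [_arr[i] for i in range(3)] == [3, 1, 4]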
def test(
lib,
handle,
torch_device,
x_shape,
k_shape,
padding,
strides,
tensor_dtype=torch.float16,
sync=None
):
print(
f"Testing MaxPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}"
)
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
y = torch.rand(
inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype
).to(torch_device)
for i in range(NUM_PRERUN if PROFILE else 1):
ans = pool(x, k_shape, padding, strides)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = pool(x, k_shape, padding, strides)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f"pytorch time: {elapsed :6f}")
x_tensor = to_tensor(x, lib)
y_tensor = to_tensor(y, lib)
if sync is not None:
sync()
descriptor = infiniopMaxPoolDescriptor_t()
check_error(
lib.infiniopCreateMaxPoolDescriptor(
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
tuple_to_void_p(k_shape),
tuple_to_void_p(padding),
tuple_to_void_p(strides),
len(k_shape),
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate()
workspaceSize = ctypes.c_uint64(0)
check_error(
lib.infiniopGetMaxPoolWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
)
workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(
torch_device
)
workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
for i in range(NUM_PRERUN if PROFILE else 1):
check_error(
lib.infiniopMaxPool(
descriptor,
workspace_ptr,
workspaceSize,
y_tensor.data,
x_tensor.data,
None,
)
)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopMaxPool(
descriptor,
workspace_ptr,
workspaceSize,
y_tensor.data,
x_tensor.data,
None,
)
)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed :6f}")
assert torch.allclose(y, ans, atol=0, rtol=1e-3)
check_error(lib.infiniopDestroyMaxPoolDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases:
# fmt: off
test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "cpu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
# fmt: on
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases:
# fmt: off
test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "cuda", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
# fmt: on
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for x_shape, kernel_shape, padding, strides in test_cases:
# fmt: off
test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
# fmt: on
destroy_handle(lib, handle)
if __name__ == "__main__":
test_cases = [
# x_shape, kernel_shape, padding, strides
((1, 1, 10), (3,), (1,), (1,)),
((32, 3, 224, 224), (3, 3), (1, 1), (2, 2)),
((1, 1, 16, 16, 16), (5, 5, 5), (2, 2, 2), (2, 2, 2)),
]
args = get_args()
lib = open_lib()
lib.infiniopCreateMaxPoolDescriptor.restype = c_int32
lib.infiniopCreateMaxPoolDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopMaxPoolDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_void_p,
c_void_p,
c_void_p,
c_uint64,
]
lib.infiniopGetMaxPoolWorkspaceSize.restype = c_int32
lib.infiniopGetMaxPoolWorkspaceSize.argtypes = [
infiniopMaxPoolDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopMaxPool.restype = c_int32
lib.infiniopMaxPool.argtypes = [
infiniopMaxPoolDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyMaxPoolDescriptor.restype = c_int32
lib.infiniopDestroyMaxPoolDescriptor.argtypes = [
infiniopMaxPoolDescriptor_t,
]
if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if args.bang:
test_bang(lib, test_cases)
if not (args.cpu or args.cuda or args.bang):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float, c_bool
import ctypes
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
CTensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
rearrange_tensor,
create_workspace,
)
from operatorspy.tests.test_utils import get_args
import torch
import torch.nn as nn
class MLPDescriptor(Structure):
_fields_ = [("device", c_int32)]
infiniopMLPDescriptor_t = POINTER(MLPDescriptor)
def swiglu(a, b):
return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
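
# A hedged aside (illustration, not part of this file): swiglu above is
# a * SiLU(b), since b / (1 + exp(-b)) == b * sigmoid(b); torch's builtin
# agrees numerically.
_a, _b = torch.tensor([1.0, 2.0]), torch.tensor([0.0, 1.0])
assert torch.allclose(swiglu(_a, _b), _a * torch.nn.functional.silu(_b))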
def mlp(y, x, w12, w3, alpha, residual):
input_dtype = x.dtype
intermediate_size = w3.shape[0]
a = torch.matmul(
x.to(torch.float32), w12[:, intermediate_size:].to(torch.float32)
).to(input_dtype)
b = torch.matmul(
x.to(torch.float32), w12[:, 0:intermediate_size].to(torch.float32)
).to(input_dtype)
c = swiglu(a, b)
d = torch.matmul(c.to(torch.float32), alpha * w3.to(torch.float32)).to(input_dtype)
out = d + y if residual else d
return out
def test(
lib,
handle,
torch_device,
num_tokens,
hidden_size,
intermediate_size,
alpha,
residual,
dtype=torch.float16,
x_stride=None,
y_stride=None,
w12_stride=None,
w3_stride=None,
sync=None
):
print(
f"Testing MLP on {torch_device} with num_tokens:{num_tokens} hidden_size:{hidden_size} intermediate_size:{intermediate_size}"
f" alpha:{alpha} residual:{residual} dtype:{dtype} x_stride:{x_stride} y_stride:{y_stride} w12_stride:{w12_stride} w3_stride:{w3_stride}"
)
y = torch.rand([num_tokens, hidden_size], dtype=dtype).to(torch_device) * 0.01
x = torch.rand([num_tokens, hidden_size], dtype=dtype).to(torch_device) * 0.01
w12 = (
torch.rand([hidden_size, 2 * intermediate_size], dtype=dtype).to(torch_device)
* 0.01
)
w3 = (
torch.rand([intermediate_size, hidden_size], dtype=dtype).to(torch_device)
* 0.01
)
ans = mlp(y, x, w12, w3, alpha, residual)
if x_stride is not None:
x = rearrange_tensor(x, x_stride)
if y_stride is not None:
y = rearrange_tensor(y, y_stride)
if w12_stride is not None:
w12 = rearrange_tensor(w12, w12_stride)
if w3_stride is not None:
w3 = rearrange_tensor(w3, w3_stride)
y_tensor = to_tensor(y, lib)
x_tensor = to_tensor(x, lib)
w12_tensor = to_tensor(w12, lib)
w3_tensor = to_tensor(w3, lib)
if sync is not None:
sync()
descriptor = infiniopMLPDescriptor_t()
check_error(
lib.infiniopCreateMLPDescriptor(
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
x_tensor.descriptor,
w12_tensor.descriptor,
w3_tensor.descriptor,
alpha,
residual,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
y_tensor.descriptor.contents.invalidate()
x_tensor.descriptor.contents.invalidate()
w12_tensor.descriptor.contents.invalidate()
w3_tensor.descriptor.contents.invalidate()
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetMLPWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, x.device)
check_error(
lib.infiniopMLP(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
y_tensor.data,
x_tensor.data,
w12_tensor.data,
w3_tensor.data,
None,
)
)
assert torch.allclose(y, ans, atol=0, rtol=2e-2)
check_error(lib.infiniopDestroyMLPDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for (
num_tokens,
hidden_size,
intermediate_size,
alpha,
residual,
dtype,
x_stride,
y_stride,
w12_stride,
w3_stride,
) in test_cases:
test(
lib,
handle,
"cpu",
num_tokens,
hidden_size,
intermediate_size,
alpha,
residual,
dtype,
x_stride,
y_stride,
w12_stride,
w3_stride,
)
destroy_handle(lib, handle)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for (
num_tokens,
hidden_size,
intermediate_size,
alpha,
residual,
dtype,
x_stride,
y_stride,
w12_stride,
w3_stride,
) in test_cases:
test(
lib,
handle,
"cuda",
num_tokens,
hidden_size,
intermediate_size,
alpha,
residual,
dtype,
x_stride,
y_stride,
w12_stride,
w3_stride,
)
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for (
num_tokens,
hidden_size,
intermediate_size,
alpha,
residual,
dtype,
x_stride,
y_stride,
w12_stride,
w3_stride,
) in test_cases:
test(
lib,
handle,
"mlu",
num_tokens,
hidden_size,
intermediate_size,
alpha,
residual,
dtype,
x_stride,
y_stride,
w12_stride,
w3_stride,
)
destroy_handle(lib, handle)
if __name__ == "__main__":
test_cases = [
# num_tokens, hidden_size, intermediate_size, alpha, residual, dtype, x_stride, y_stride, w12_stride, w3_stride
(4, 4096, 11008, 1.0, True, torch.float16, None, None, None, None),
(4, 4096, 11008, 1.0, True, torch.float16, [8192, 1], [8192, 1], None, None),
(
4,
4096,
11008,
1.0,
True,
torch.float16,
None,
None,
[1, 4096],
[1, 11008],
),
(4, 4096, 11008, 1.0, False, torch.float16, None, None, None, None),
(4, 4096, 11008, 1.0, False, torch.float16, [8192, 1], [8192, 1], None, None),
]
args = get_args()
lib = open_lib()
lib.infiniopCreateMLPDescriptor.restype = c_int32
lib.infiniopCreateMLPDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopMLPDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_float,
c_bool,
]
lib.infiniopGetMLPWorkspaceSize.restype = c_int32
lib.infiniopGetMLPWorkspaceSize.argtypes = [
infiniopMLPDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopMLP.restype = c_int32
lib.infiniopMLP.argtypes = [
infiniopMLPDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyMLPDescriptor.restype = c_int32
lib.infiniopDestroyMLPDescriptor.argtypes = [
infiniopMLPDescriptor_t,
]
if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if args.bang:
test_bang(lib, test_cases)
if not (args.cpu or args.cuda or args.bang):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
...@@ -58,126 +59,92 @@ _TEST_CASES = [
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
    InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def mul(c, a, b):
    torch.mul(a, b, out=c)
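
# A hedged aside (illustration, not from the diff): torch.mul allows out= to
# alias an input, which is what the INPLACE_A / INPLACE_B paths below rely on.
_a, _b = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])
torch.mul(_a, _b, out=_a)
assert _a.tolist() == [3.0, 8.0]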
def test(
    handle,
    device,
    shape,
    a_stride=None,
    b_stride=None,
    c_stride=None,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=InfiniDtype.F16,
    sync=None,
):
    a = TestTensor(shape, a_stride, dtype, device)
    b = TestTensor(shape, b_stride, dtype, device)
    if inplace == Inplace.INPLACE_A:
        if c_stride is not None and c_stride != a_stride:
            return
        c = a
    elif inplace == Inplace.INPLACE_B:
        if c_stride is not None and c_stride != b_stride:
            return
        c = b
    else:
        c = TestTensor(shape, c_stride, dtype, device)

    if c.is_broadcast():
        return
    print(
        f"Testing Mul on {InfiniDeviceNames[device]} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} "
        f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
    )
    mul(c.torch_tensor(), a.torch_tensor(), b.torch_tensor())

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateMulDescriptor(
            handle,
            ctypes.byref(descriptor),
            c.descriptor,
            a.descriptor,
            b.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a, b, c]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetMulWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, c.device)

    def lib_mul():
        check_error(
            LIBINFINIOP.infiniopMul(
                descriptor,
                workspace.data(),
                workspace_size.value,
                c.data(),
                a.data(),
                b.data(),
                None,
            )
        )
...@@ -186,52 +153,20 @@ def test(
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(c.actual_tensor(), c.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(c.actual_tensor(), c.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: mul(c.torch_tensor(), a.torch_tensor(), b.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_mul(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyMulDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()

    # Configure testing options
    DEBUG = args.debug
...@@ -240,7 +175,6 @@ if __name__ == "__main__":
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug_all,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ==============================================================================
...@@ -37,11 +37,11 @@ _TEST_CASES = [
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.BF16: {"atol": 0, "rtol": 0},
}
...@@ -51,13 +51,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def random_sample(data, random_val, topp, topk, voc, temperature):
    if topp > 0 and topk > 1:
        sorted_vals, sorted_indices = torch.sort(data, descending=True)
...@@ -68,81 +61,81 @@ def random_sample(data, random_val, topp, topk, voc, temperature):
        k_index = min(topk, voc) - 1
        threshold = min(cum_probs[k_index], topp) * random_val
        try:
            idx = torch.searchsorted(cum_probs, threshold)
        except Exception:
            # Fall back to a manual search if torch.searchsorted is not supported
            indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
            idx = (
                indices[0]
                if indices.numel() > 0
                else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
            )
        return sorted_indices[idx]
    return torch.argmax(data)
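
# A hedged worked example of the top-k/top-p cut above (illustrative numbers):
_probs = torch.tensor([0.5, 0.3, 0.15, 0.05])        # sorted, sums to 1
_cum = torch.cumsum(_probs, dim=0)                   # [0.50, 0.80, 0.95, 1.00]
_threshold = float(min(_cum[2 - 1], 0.8)) * 0.5      # topk=2, topp=0.8, random_val=0.5
assert int(torch.searchsorted(_cum, _threshold)) == 0  # threshold 0.4 -> first token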
def test(
    handle,
    device,
    voc,
    random_val,
    topp,
    topk,
    temperature,
    dtype=InfiniDtype.F16,
    sync=None,
):
    print(
        f"Testing RandomSample on {InfiniDeviceNames[device]} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{InfiniDtypeNames[dtype]}"
    )
    _perm = torch.randperm(voc)
    logits = TestTensor.from_torch(
        torch.arange(voc)[_perm].float() * 0.0001, dtype, device
    )
    ans = random_sample(
        logits.torch_tensor(), random_val, topp, topk, voc, temperature
    )  # This function can be very slow on device; computing on data.to("cpu") speeds it up
    indices = TestTensor([], None, InfiniDtype.I32, device, mode="zeros")

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRandomSampleDescriptor(
            handle,
            ctypes.byref(descriptor),
            indices.descriptor,
            logits.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [logits, indices]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRandomSampleWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)

    def lib_random_sample():
        check_error(
            LIBINFINIOP.infiniopRandomSample(
                descriptor,
                workspace.data(),
                workspace_size.value,
                indices.data(),
                logits.data(),
                random_val,
                topp,
                topk,
...@@ -153,66 +146,36 @@ def test(
    lib_random_sample()

    if sync is not None:
        sync()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug_all(
            (indices.actual_tensor(), logits.actual_tensor()[indices.actual_tensor()]),
            (ans, logits.torch_tensor()[ans]),
            "or",
            atol=atol,
            rtol=rtol,
        )
    assert (
        indices.actual_tensor() == ans
        or logits.actual_tensor()[indices.actual_tensor()] == logits.torch_tensor()[ans]
    )

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: random_sample(
            logits.torch_tensor(), random_val, topp, topk, voc, temperature
        ), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_random_sample(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyRandomSampleDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()

    DEBUG = args.debug
    PROFILE = args.profile
...@@ -221,6 +184,6 @@ if __name__ == "__main__":

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)
def row_major_strides(shape):
    """Generate row-major (C-style) strides for a tensor

    Args:
        shape: tensor shape

    Returns:
        a list of row-major strides
    """
...@@ -34,12 +34,13 @@ def row_major_strides(shape):
        strides.insert(0, stride)
    return strides


def column_major_strides(shape):
    """Generate column-major (Fortran-style) strides for a tensor

    Args:
        shape: tensor shape

    Returns:
        a list of column-major strides
    """
...@@ -52,62 +53,37 @@ def column_major_strides(shape):
    return strides
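
# A hedged worked example (not from the diff), assuming the helpers follow the
# conventions their docstrings state, for shape (2, 3, 4):
assert row_major_strides((2, 3, 4)) == [12, 4, 1]
assert column_major_strides((2, 3, 4)) == [1, 2, 6]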
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # (shape, x_stride, y_stride)
    ((100, 100), (1, 100), (100, 1)),
    ((4, 4), (1, 4), (4, 1)),
    ((4, 6, 64), (64, 4 * 64, 1), (6 * 64, 64, 1)),
    ((2000, 2000), (1, 2000), (2000, 1)),
    ((2001, 2001), (1, 2001), (2001, 1)),
    ((2, 2, 2, 4), (16, 8, 4, 1), (16, 8, 1, 2)),
    (
        (3, 4, 7, 53, 9),  # shape
        row_major_strides((3, 4, 7, 53, 9)),  # x_stride
        column_major_strides((3, 4, 7, 53, 9)),  # y_stride
    ),
    (
        (3, 4, 50, 50, 5, 7),  # shape
        row_major_strides((3, 4, 50, 50, 5, 7)),  # x_stride
        column_major_strides((3, 4, 50, 50, 5, 7)),  # y_stride
    ),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.F32: {"atol": 0, "rtol": 0},
}

DEBUG = False
...@@ -116,106 +92,60 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def rearrange_torch(y, x, x_shape, y_stride):
    y.set_(y.untyped_storage(), 0, x_shape, y_stride)
    y[:] = x.view_as(y)
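
# A hedged aside (illustration, not from the diff): the set_()/view_as() trick
# re-views y's storage with the requested strides, then copies x through it.
_x = torch.arange(6.0).reshape(2, 3)
_y = torch.empty_like(_x)
_y.set_(_y.untyped_storage(), 0, (2, 3), (1, 2))   # column-major view
_y[:] = _x.view_as(_y)
assert _y.stride() == (1, 2) and _y.tolist() == _x.tolist()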
def test(
    handle, device, shape, x_stride, y_stride, dtype=InfiniDtype.F16, sync=None
):
    print(
        f"Testing Rearrange on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]}"
    )
    x = TestTensor(shape, x_stride, dtype, device)
    y = TestTensor(shape, y_stride, dtype, device, mode="ones")
    rearrange_torch(y.torch_tensor(), x.torch_tensor(), shape, y_stride)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRearrangeDescriptor(
            handle, ctypes.byref(descriptor), y.descriptor, x.descriptor
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x, y]:
        tensor.destroy_desc()

    def lib_rearrange():
        check_error(LIBINFINIOP.infiniopRearrange(descriptor, y.data(), x.data(), None))

    lib_rearrange()

    # Validate results
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rearrange_torch(y.torch_tensor(), x.torch_tensor(), shape, y_stride), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_rearrange(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyRearrangeDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
...@@ -224,6 +154,6 @@ if __name__ == "__main__":

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ==============================================================================
...@@ -33,23 +32,21 @@ _TEST_CASES_ = [
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
]

# w (weight) types
# Note: 'None' means the same as input dtype
_WEIGHT_DTYPES = [None, InfiniDtype.F32]
# x types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_
_TEST_CASES = [
    test_case + (w_dtype,) for test_case in _TEST_CASES_ for w_dtype in _WEIGHT_DTYPES
]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
    InfiniDtype.BF16: {"atol": 8e-3, "rtol": 8e-3},
}

DEBUG = False
...@@ -58,13 +55,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def rms_norm(ans, x, w, eps):
    torch.pow(x, 2, out=ans)
    mean = torch.mean(ans, dim=-1, keepdim=True)
...@@ -75,73 +65,67 @@ def rms_norm(ans, x, w, eps):
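
# A hedged reference check (not from the diff) of the RMSNorm formula the
# elided body computes: y = x / sqrt(mean(x^2) + eps) * w.
_x, _w, _eps = torch.tensor([[1.0, 2.0, 3.0]]), torch.ones(3), 1e-6
_y = _x / torch.sqrt(torch.mean(_x * _x, dim=-1, keepdim=True) + _eps) * _w
assert _y.shape == (1, 3)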
def test(
    handle,
    device,
    y_shape,
    x_shape,
    w_shape,
    y_stride,
    x_stride,
    w_dtype=InfiniDtype.F32,
    dtype=InfiniDtype.F16,
    sync=None,
):
    w_dtype = w_dtype if w_dtype else dtype
    print(
        f"Testing RMS_Norm on {InfiniDeviceNames[device]} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
        f" y_stride:{y_stride} x_stride:{x_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}"
    )
    y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
    x = TestTensor(x_shape, x_stride, dtype, device, scale=0.01)
    w = TestTensor(w_shape, None, w_dtype, device)

    eps = 1e-6
    rms_norm(y.torch_tensor(), x.torch_tensor(), w.torch_tensor(), eps)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            w.descriptor,
            eps,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x, y, w]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRMSNormWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, y.device)

    def lib_rms_norm():
        check_error(
            LIBINFINIOP.infiniopRMSNorm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                w.data(),
                None,
            )
        )
...@@ -150,53 +134,20 @@ def test(
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rms_norm(y.torch_tensor(), x.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyRMSNormDescriptor(descriptor))
 if __name__ == "__main__":
     args = get_args()
-    lib = open_lib()
-    lib.infiniopCreateRMSNormDescriptor.restype = c_int32
-    lib.infiniopCreateRMSNormDescriptor.argtypes = [
-        infiniopHandle_t,
-        POINTER(infiniopRMSNormDescriptor_t),
-        infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t,
-        c_float,
-    ]
-    lib.infiniopGetRMSNormWorkspaceSize.restype = c_int32
-    lib.infiniopGetRMSNormWorkspaceSize.argtypes = [
-        infiniopRMSNormDescriptor_t,
-        POINTER(c_uint64),
-    ]
-    lib.infiniopRMSNorm.restype = c_int32
-    lib.infiniopRMSNorm.argtypes = [
-        infiniopRMSNormDescriptor_t,
-        c_void_p,
-        c_uint64,
-        c_void_p,
-        c_void_p,
-        c_void_p,
-        c_void_p,
-    ]
-    lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
-    lib.infiniopDestroyRMSNormDescriptor.argtypes = [
-        infiniopRMSNormDescriptor_t,
-    ]

     # Configure testing options
     DEBUG = args.debug
@@ -206,6 +157,6 @@ if __name__ == "__main__":
     # Execute tests
     for device in get_test_devices(args):
-        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
+        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

     print("\033[92mTest passed!\033[0m")
 import torch
 import ctypes
-from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
+from ctypes import c_uint64
 from libinfiniop import (
-    infiniopHandle_t,
-    infiniopTensorDescriptor_t,
-    open_lib,
-    to_tensor,
+    LIBINFINIOP,
+    TestTensor,
     get_test_devices,
     check_error,
-    rearrange_if_needed,
-    create_workspace,
     test_operator,
     get_args,
     debug,
     get_tolerance,
     profile_operation,
-    synchronize_device,
+    TestWorkspace,
+    InfiniDtype,
+    InfiniDtypeNames,
+    InfiniDeviceEnum,
+    InfiniDeviceNames,
+    infiniopOperatorDescriptor_t,
 )
 from enum import Enum, auto
@@ -35,13 +36,13 @@ _TEST_CASES_ = [
 ]

 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
+_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]

 # Tolerance map for different data types
 _TOLERANCE_MAP = {
-    torch.float16: {"atol": 1e-3, "rtol": 1e-2},
-    torch.bfloat16: {"atol": 5e-3, "rtol": 5e-2},
-    torch.float32: {"atol": 1e-4, "rtol": 1e-3},
+    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
+    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
+    InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-3},
 }
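The map is consumed later via `atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)`. The helper itself lives in `libinfiniop` and is not shown in this diff; a plausible lookup consistent with that call site (hypothetical, including the fallback defaults):

def get_tolerance(tolerance_map, dtype, default_atol=1e-3, default_rtol=1e-3):
    # Hypothetical sketch: return the per-dtype (atol, rtol) pair,
    # falling back to defaults for dtypes missing from the map.
    entry = tolerance_map.get(dtype, {})
    return entry.get("atol", default_atol), entry.get("rtol", default_rtol)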
@@ -67,14 +68,7 @@ NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

-class RoPEDescriptor(Structure):
-    _fields_ = [("device", c_int32)]
-
-
-infiniopRoPEDescriptor_t = POINTER(RoPEDescriptor)
-
-
-def rotary_embedding(t, sin, cos, torch_device):
+def rotary_embedding(ans, t, sin, cos, device):
     dh = t.shape[2]
     dt = t.dtype
     assert dh % 2 == 0, "Embedding dimension must be even."
@@ -82,7 +76,7 @@ def rotary_embedding(t, sin, cos, torch_device):
     t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
     cos = cos.unsqueeze(1)  # [seq_len, 1, dh // 2]
     sin = sin.unsqueeze(1)  # [seq_len, 1, dh // 2]
-    if torch_device == "cpu":
+    if device == InfiniDeviceEnum.CPU:
         (t_even, t_odd, cos, sin) = (
             t_even.float(),
             t_odd.float(),
@@ -93,26 +87,23 @@ def rotary_embedding(t, sin, cos, torch_device):
     t_out_even = t_even * cos - t_odd * sin
     t_out_odd = t_even * sin + t_odd * cos

-    t_out = torch.empty_like(t)
-    t_out[..., 0::2] = t_out_even
-    t_out[..., 1::2] = t_out_odd
-    return t_out.to(dt).to(torch_device)
+    ans[..., 0::2] = t_out_even.to(dt)
+    ans[..., 1::2] = t_out_odd.to(dt)
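The even/odd arithmetic above is the standard RoPE rotation: each consecutive feature pair (t_{2i}, t_{2i+1}) is rotated by a position- and frequency-dependent angle. With position p, head dimension d = dh, and base theta (1e5 in this test):

$$
\begin{pmatrix} t'_{2i} \\ t'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos\theta_{p,i} & -\sin\theta_{p,i} \\ \sin\theta_{p,i} & \cos\theta_{p,i} \end{pmatrix}
\begin{pmatrix} t_{2i} \\ t_{2i+1} \end{pmatrix},
\qquad
\theta_{p,i} = p \cdot \theta^{-2i/d}, \quad i = 0, \dots, d/2 - 1
$$

Multiplying out gives exactly the two lines above: t_out_even = t_even * cos - t_odd * sin and t_out_odd = t_even * sin + t_odd * cos. The angles theta_{p,i} are what sin_cos_table (next) precomputes.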
-def sin_cos_table(pos, dim, torch_device, theta, dtype):
+def sin_cos_table(pos, dim, device, theta, dtype):
     assert dim % 2 == 0, "Embedding dimension must be even."
-    freqs = (1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))).to(
-        torch_device
-    )
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     angles = torch.outer(pos, freqs)
-    return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
+    return (
+        TestTensor.from_torch(torch.sin(angles), dtype, device),
+        TestTensor.from_torch(torch.cos(angles), dtype, device),
+    )
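A quick, self-contained sanity check of the table construction (pure PyTorch, independent of the library; the shapes match the comments in rotary_embedding):

import torch

pos = torch.arange(4).float()  # positions 0..3
dim, theta = 8, 1e5
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
angles = torch.outer(pos, freqs)  # [seq_len, dim // 2]
assert angles.shape == (4, 4)
assert torch.all(angles[0] == 0)  # position 0 is rotated by zero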
 def test(
-    lib,
     handle,
-    torch_device,
+    device,
     shape,
     x_strides=None,
     y_strides=None,
@@ -120,71 +111,71 @@ def test(
     dtype=torch.float32,
     sync=None,
 ):
+    x = TestTensor(shape, x_strides, dtype, device)
     if inplace == Inplace.INPLACE_X:
-        y_strides = x_strides
-    print(
-        f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{dtype} inplace:{inplace}"
-    )
-
-    x = torch.rand(shape, dtype=dtype).to(torch_device)
-    x = rearrange_if_needed(x, x_strides)
-    if inplace == Inplace.INPLACE_X:
+        if x_strides != y_strides:
+            return
         y = x
     else:
-        y = torch.rand(shape, dtype=dtype).to(torch_device)
-        y = rearrange_if_needed(y, y_strides)
+        y = TestTensor(shape, y_strides, dtype, device)
+
+    print(
+        f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
+    )

     theta = 1e5
-    pos = torch.arange(0, x.shape[0], dtype=torch.int32).to(torch_device)
-    sin_table, cos_table = sin_cos_table(pos, x.shape[2], x.device, theta, dtype)
-
-    ans = rotary_embedding(x, sin_table, cos_table, torch_device)
-
-    descriptor = infiniopRoPEDescriptor_t()
-
-    x_tensor, pos_tensor, sin_table_tensor, cos_table_tensor = [
-        to_tensor(tensor, lib, force_unsigned=True)
-        for tensor in [x, pos, sin_table, cos_table]
-    ]
-
-    if inplace == Inplace.INPLACE_X:
-        y_tensor = x_tensor
-    else:
-        y_tensor = to_tensor(y, lib)
+    pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
+    sin_table, cos_table = sin_cos_table(
+        pos.torch_tensor(), x.shape[2], x.device, theta, dtype
+    )
+
+    rotary_embedding(
+        y.torch_tensor(),
+        x.torch_tensor(),
+        sin_table.torch_tensor(),
+        cos_table.torch_tensor(),
+        device,
+    )
+
+    descriptor = infiniopOperatorDescriptor_t()

     if sync is not None:
         sync()

     check_error(
-        lib.infiniopCreateRoPEDescriptor(
+        LIBINFINIOP.infiniopCreateRoPEDescriptor(
             handle,
             ctypes.byref(descriptor),
-            y_tensor.descriptor,
-            x_tensor.descriptor,
-            pos_tensor.descriptor,
-            sin_table_tensor.descriptor,
-            cos_table_tensor.descriptor,
+            y.descriptor,
+            x.descriptor,
+            pos.descriptor,
+            sin_table.descriptor,
+            cos_table.descriptor,
         )
     )

     # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
-    for tensor in [y_tensor, x_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]:
-        tensor.destroyDesc(lib)
+    for tensor in [y, x, pos, sin_table, cos_table]:
+        tensor.destroy_desc()

     workspace_size = c_uint64(0)
     check_error(
-        lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
+        LIBINFINIOP.infiniopGetRoPEWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
     )
-    workspace = create_workspace(workspace_size.value, x.device)
+    workspace = TestWorkspace(workspace_size.value, x.device)

     def lib_rope():
         check_error(
-            lib.infiniopRoPE(
+            LIBINFINIOP.infiniopRoPE(
                 descriptor,
-                workspace.data_ptr() if workspace is not None else None,
+                workspace.data(),
                 workspace_size.value,
-                y_tensor.data,
-                x_tensor.data,
-                pos_tensor.data,
-                sin_table_tensor.data,
-                cos_table_tensor.data,
+                y.data(),
+                x.data(),
+                pos.data(),
+                sin_table.data(),
+                cos_table.data(),
                 None,
             )
         )
@@ -196,60 +187,32 @@ def test(
     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:
-        debug(y, ans, atol=atol, rtol=rtol)
-    assert torch.allclose(y, ans, atol=atol, rtol=rtol)
+        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
+    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

     if PROFILE:
         profile_operation(
             "PyTorch",
-            lambda: rotary_embedding(x, sin_table, cos_table, torch_device),
-            torch_device,
+            lambda: rotary_embedding(
+                y.torch_tensor(),
+                x.torch_tensor(),
+                sin_table.torch_tensor(),
+                cos_table.torch_tensor(),
+                device,
+            ),
+            device,
             NUM_PRERUN,
             NUM_ITERATIONS,
         )
         profile_operation(
-            " lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
+            " lib", lambda: lib_rope(), device, NUM_PRERUN, NUM_ITERATIONS
         )

-    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
+    check_error(LIBINFINIOP.infiniopDestroyRoPEDescriptor(descriptor))


 if __name__ == "__main__":
     args = get_args()
-    lib = open_lib()
-    lib.infiniopCreateRoPEDescriptor.restype = c_int32
-    lib.infiniopCreateRoPEDescriptor.argtypes = [
-        infiniopHandle_t,
-        POINTER(infiniopRoPEDescriptor_t),
-        infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t,
-    ]
-    lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
-    lib.infiniopGetRoPEWorkspaceSize.argtypes = [
-        infiniopRoPEDescriptor_t,
-        POINTER(c_uint64),
-    ]
-    lib.infiniopRoPE.restype = c_int32
-    lib.infiniopRoPE.argtypes = [
-        infiniopRoPEDescriptor_t,
-        c_void_p,
-        c_uint64,
-        c_void_p,
-        c_void_p,
-        c_void_p,
-        c_void_p,
-        c_void_p,
-    ]
-    lib.infiniopDestroyRoPEDescriptor.restype = c_int32
-    lib.infiniopDestroyRoPEDescriptor.argtypes = [
-        infiniopRoPEDescriptor_t,
-    ]

     # Configure testing options
     DEBUG = args.debug
@@ -259,6 +222,6 @@ if __name__ == "__main__":
     # Execute tests
     for device in get_test_devices(args):
-        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
+        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

     print("\033[92mTest passed!\033[0m")