Unverified Commit 8d09630a authored by gongchensu, committed by GitHub

Merge branch 'demo131' into Issue/862

parents ab52dead 012df56c
......@@ -183,6 +183,7 @@ def func6_initialize_device_relationship():
_infinicore.Device.Type.QY, # 9 "cuda"
_infinicore.Device.Type.KUNLUN, # 7 "cuda"
_infinicore.Device.Type.HYGON, # 8 "cuda"
_infinicore.Device.Type.ALI, # 10 "cuda"
]
if True:
print("\n ---------- 测试 CPU")
......
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# y_shape, a_shape, b_shape, w_shape, y_stride, a_stride, b_stride
((1, 4), (1, 4), (1, 4), (4,), None, None, None),
((2, 4), (2, 4), (2, 4), (4,), None, None, None),
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(2048, 8192, 1),
(2048, 8192, 1),
(2048, 8192, 1),
),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(16384, 4096, 1),
(16384, 4096, 1),
(16384, 4096, 1),
),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None),
]
# w (weight) types
# Note: 'None' means the same as input dtype
_WEIGHT_DTYPES = [None, InfiniDtype.F32, InfiniDtype.F16, InfiniDtype.BF16]
# a, b types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]
# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_
_TEST_CASES = [
test_case + (w_dtype,) for test_case in _TEST_CASES_ for w_dtype in _WEIGHT_DTYPES
]
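# For illustration only (this tuple is generated above, not written out anywhere):
# the first base case combined with w_dtype=None expands to
#   ((1, 4), (1, 4), (1, 4), (4,), None, None, None, None)
# i.e. the weight dtype is appended as the last element and unpacked into
# test()'s w_dtype parameter.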
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def add_rms_norm(ans, a, b, w, eps):
input_dtype = a.dtype
# Compute add(a, b)
sum_tensor = a.to(torch.float32) + b.to(torch.float32)
# Compute RMS normalization
scale = sum_tensor.pow(2).mean(-1, keepdim=True).add_(eps).rsqrt_()
ans.set_((sum_tensor.mul_(scale).mul_(w.to(torch.float32))).to(input_dtype))
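# Restating the reference above (computed in float32 for stability):
#   residual = a + b
#   y = residual / sqrt(mean(residual**2, dim=-1) + eps) * w
# with the result cast back to the input dtype.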
def test(
handle,
device,
y_shape,
a_shape,
b_shape,
w_shape,
y_stride,
a_stride,
b_stride,
w_dtype=InfiniDtype.F32,
dtype=InfiniDtype.F16,
sync=None,
):
w_dtype = w_dtype if w_dtype else dtype
print(
f"Testing AddRMSNorm on {InfiniDeviceNames[device]} with y_shape:{y_shape} a_shape:{a_shape} b_shape:{b_shape} w_shape:{w_shape}"
f" y_stride:{y_stride} a_stride:{a_stride} b_stride:{b_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}"
)
y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
residual_out = TestTensor(a_shape, a_stride, dtype, device, mode="ones")
a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01)
b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01)
w = TestTensor(w_shape, None, w_dtype, device)
eps = 1e-6
add_rms_norm(
y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps
)
if sync is not None:
sync()
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateAddRMSNormDescriptor(
handle,
ctypes.byref(descriptor),
y.descriptor,
residual_out.descriptor,
a.descriptor,
b.descriptor,
w.descriptor,
eps,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for tensor in [a, b, y, w, residual_out]:
tensor.destroy_desc()
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetAddRMSNormWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, y.device)
def lib_add_rms_norm():
check_error(
LIBINFINIOP.infiniopAddRMSNorm(
descriptor,
workspace.data(),
workspace_size.value,
y.data(),
residual_out.data(),
a.data(),
b.data(),
w.data(),
None,
)
)
lib_add_rms_norm()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Verify normalized result (y)
if DEBUG:
debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
# Verify add result (residual_out) - should be a + b
expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(
torch.float32
)
expected_residual = expected_residual.to(a.torch_tensor().dtype)
if DEBUG:
debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
assert torch.allclose(
residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol
)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_add_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(LIBINFINIOP.infiniopDestroyAddRMSNormDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
# Execute tests
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
......@@ -9,6 +9,7 @@ class InfiniDeviceEnum:
KUNLUN = 7
HYGON = 8
QY = 9
ALI = 10
InfiniDeviceNames = {
......@@ -22,6 +23,7 @@ InfiniDeviceNames = {
InfiniDeviceEnum.KUNLUN: "Kunlun",
InfiniDeviceEnum.HYGON: "Hygon",
InfiniDeviceEnum.QY: "QY",
InfiniDeviceEnum.ALI: "Ali",
}
# Mapping that maps InfiniDeviceEnum to torch device string
......@@ -36,4 +38,5 @@ torch_device_map = {
InfiniDeviceEnum.KUNLUN: "cuda",
InfiniDeviceEnum.HYGON: "cuda",
InfiniDeviceEnum.QY: "cuda",
InfiniDeviceEnum.ALI: "cuda",
}
......@@ -383,6 +383,45 @@ def rms_norm_(lib):
]
@OpRegister.operator
def add_rms_norm_(lib):
lib.infiniopCreateAddRMSNormDescriptor.restype = c_int32
lib.infiniopCreateAddRMSNormDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_float,
]
lib.infiniopGetAddRMSNormWorkspaceSize.restype = c_int32
lib.infiniopGetAddRMSNormWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopAddRMSNorm.restype = c_int32
lib.infiniopAddRMSNorm.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32
lib.infiniopDestroyAddRMSNormDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def rope_(lib):
lib.infiniopCreateRoPEDescriptor.restype = c_int32
......@@ -686,6 +725,41 @@ def dequantize_(lib):
]
@OpRegister.operator
def per_channel_quant_int8_(lib):
lib.infiniopCreatePerChannelQuantI8Descriptor.restype = c_int32
lib.infiniopCreatePerChannelQuantI8Descriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetPerChannelQuantI8WorkspaceSize.restype = c_int32
lib.infiniopGetPerChannelQuantI8WorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopPerChannelQuantI8.restype = c_int32
lib.infiniopPerChannelQuantI8.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyPerChannelQuantI8Descriptor.restype = c_int32
lib.infiniopDestroyPerChannelQuantI8Descriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def softplus_(lib):
lib.infiniopCreateSoftplusDescriptor.restype = c_int32
......@@ -938,3 +1012,204 @@ def tanh_(lib):
lib.infiniopDestroyTanhDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def scaled_mm_int8_(lib):
lib.infiniopCreateI8GemmDescriptor.restype = c_int32
lib.infiniopCreateI8GemmDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetI8GemmWorkspaceSize.restype = c_int32
lib.infiniopGetI8GemmWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopI8Gemm.restype = c_int32
lib.infiniopI8Gemm.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyI8GemmDescriptor.restype = c_int32
lib.infiniopDestroyI8GemmDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def paged_attention_(lib):
lib.infiniopCreatePagedAttentionDescriptor.restype = c_int32
lib.infiniopCreatePagedAttentionDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_void_p,
c_float,
]
lib.infiniopGetPagedAttentionWorkspaceSize.restype = c_int32
lib.infiniopGetPagedAttentionWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopPagedAttention.restype = c_int32
lib.infiniopPagedAttention.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyPagedAttentionDescriptor.restype = c_int32
lib.infiniopDestroyPagedAttentionDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def paged_caching_(lib):
lib.infiniopCreatePagedCachingDescriptor.restype = c_int32
lib.infiniopCreatePagedCachingDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t, # k_cache_desc
infiniopTensorDescriptor_t, # v_cache_desc
infiniopTensorDescriptor_t, # k_desc
infiniopTensorDescriptor_t, # v_desc
infiniopTensorDescriptor_t, # slot_mapping_desc
]
# infiniopGetPagedCachingWorkspaceSize
lib.infiniopGetPagedCachingWorkspaceSize.restype = c_int32
lib.infiniopGetPagedCachingWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
# infiniopPagedCaching
lib.infiniopPagedCaching.restype = c_int32
lib.infiniopPagedCaching.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p, # workspace
c_size_t, # workspace_size
c_void_p, # k_cache
c_void_p, # v_cache
c_void_p, # k
c_void_p, # v
c_void_p, # slot_mapping
c_void_p, # stream
]
# infiniopDestroyPagedCachingDescriptor
lib.infiniopDestroyPagedCachingDescriptor.restype = c_int32
lib.infiniopDestroyPagedCachingDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def paged_attention_prefill_(lib):
lib.infiniopCreatePagedAttentionPrefillDescriptor.restype = c_int32
lib.infiniopCreatePagedAttentionPrefillDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_float,
]
lib.infiniopGetPagedAttentionPrefillWorkspaceSize.restype = c_int32
lib.infiniopGetPagedAttentionPrefillWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopPagedAttentionPrefill.restype = c_int32
lib.infiniopPagedAttentionPrefill.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyPagedAttentionPrefillDescriptor.restype = c_int32
lib.infiniopDestroyPagedAttentionPrefillDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def silu_and_mul(lib):
lib.infiniopCreateSiluAndMulDescriptor.restype = c_int32
lib.infiniopCreateSiluAndMulDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetSiluAndMulWorkspaceSize.restype = c_int32
lib.infiniopGetSiluAndMulWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopSiluAndMul.restype = c_int32
lib.infiniopSiluAndMul.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroySiluAndMulDescriptor.restype = c_int32
lib.infiniopDestroySiluAndMulDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
......@@ -336,7 +336,7 @@ def rearrange_tensor(tensor, new_strides):
torch.float32,
torch.float64,
]:
new_tensor.view(-1).index_add_(0, new_positions, tensor.view(-1))
new_tensor.view(-1).index_add_(0, new_positions, tensor.contiguous().view(-1))
elif tensor.dtype in [torch.uint16, torch.uint32, torch.uint64]:
new_tensor_int64 = new_tensor.to(dtype=torch.int64)
tensor_int64 = tensor.to(dtype=torch.int64)
......@@ -433,6 +433,11 @@ def get_args():
action="store_true",
help="Run HYGON DCU test",
)
parser.add_argument(
"--ali",
action="store_true",
help="Run ALI PPU test",
)
return parser.parse_args()
......@@ -487,6 +492,7 @@ def filter_tensor_dtypes_by_device(device, tensor_dtypes):
InfiniDeviceEnum.ASCEND,
InfiniDeviceEnum.ILUVATAR,
InfiniDeviceEnum.CAMBRICON,
InfiniDeviceEnum.ALI,
):
return tensor_dtypes
else:
......@@ -757,6 +763,10 @@ def get_test_devices(args):
import torch
devices_to_test.append(InfiniDeviceEnum.HYGON)
if args.ali:
import torch
devices_to_test.append(InfiniDeviceEnum.ALI)
if not devices_to_test:
devices_to_test = [InfiniDeviceEnum.CPU]
......
import torch
import ctypes
from ctypes import c_uint64
import math
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
TestWorkspace,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def get_alibi_slopes(n):
# Simplified ALiBi slope computation
# Reference: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L742
closest_power_of_2 = 2 ** math.floor(math.log2(n))
base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
powers = [base**i for i in range(1, closest_power_of_2 + 1)]
if n > closest_power_of_2:
extra = [base ** (i * 2) for i in range(1, 2 * (n - closest_power_of_2) + 1, 2)]
powers += extra
return powers[:n]
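# Worked example of the formula above: for n=8 heads, closest_power_of_2 = 8
# and base = 2**-1, so the slopes are [2**-1, 2**-2, ..., 2**-8]. For head
# counts that are not powers of two, extra slopes are appended from the same
# base, which is a simplification of the original ALiBi recipe.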
def ref_masked_attention(query, key, value, scale, attn_mask=None):
# Reference implementation for a single masked attention head.
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
if attn_mask is not None:
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
def ref_single_query_cached_kv_attention(
query, key_cache, value_cache, block_tables, seq_lens, scale, alibi_slopes
):
# Reference implementation for paged attention, iterating through each sequence.
output = torch.empty_like(query)
num_query_heads, num_kv_heads = query.shape[1], value_cache.shape[1]
num_queries_per_kv = num_query_heads // num_kv_heads
head_size, block_size = value_cache.shape[3], value_cache.shape[2]
num_seqs = query.shape[0]
for i in range(num_seqs):
q = query[i].unsqueeze(0)
seq_len = seq_lens[i].item()
block_table = block_tables[i]
keys_lst, values_lst = [], []
for j in range(seq_len):
block_num = block_table[j // block_size].item()
block_off = j % block_size
k = key_cache[block_num, :, block_off, :]
v = value_cache[block_num, :, block_off, :]
keys_lst.append(k)
values_lst.append(v)
keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values_lst, dim=0)
if num_queries_per_kv > 1:
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
alibi_bias = None
if alibi_slopes is not None:
pos = torch.arange(seq_len, device=query.device).int()
alibi_bias = (pos - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
output[i] = out.view(num_query_heads, head_size)
return output
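# Note on grouped-query attention in the reference above: when
# num_query_heads > num_kv_heads, each cached K/V head is repeated
# num_queries_per_kv times, so e.g. with 64 query heads and 8 KV heads
# (the (8, 64, 8, ...) case below) every group of 8 query heads attends to
# the same KV head.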
# ==============================================================================
# Test Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi)
(1, 1, 1, 128, 16, 1024, False),
(4, 40, 40, 128, 16, 1024, False),
(6, 40, 40, 128, 16, 1024, False),
(3, 8, 8, 128, 16, 1024, False),
(3, 8, 8, 64, 16, 1024, False),
(8, 64, 8, 128, 16, 2048, False),
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
}
# Global flags for controlling test behavior
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def test(
handle,
device,
num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
max_seq_len,
use_alibi,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing PagedAttention on {InfiniDeviceNames[device]} with "
f"num_seqs={num_seqs}, num_heads={num_heads}, head_size={head_size}, "
f"block_size={block_size}, dtype={InfiniDtypeNames[dtype]}, use_alibi={use_alibi}"
)
scale = 1.0 / (head_size**0.5)
max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
num_blocks = num_seqs * max_blocks_per_seq # A reasonable number for testing
# Create input tensors
q = TestTensor((num_seqs, num_heads, head_size), None, dtype, device)
out = TestTensor((num_seqs, num_heads, head_size), None, dtype, device)
k_cache = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
v_cache = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
seq_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I64, device)
block_tables_py = torch.arange(
0, num_seqs * max_blocks_per_seq, dtype=torch.int64
).view(num_seqs, max_blocks_per_seq)
block_tables = TestTensor.from_torch(block_tables_py, InfiniDtype.I64, device)
alibi_slopes_desc = ctypes.c_void_p(0)
alibi_slopes_data = ctypes.c_void_p(0)
alibi_slopes_torch = None
if use_alibi:
alibi_slopes = TestTensor((num_heads,), None, InfiniDtype.F32, device)
alibi_slopes_desc = alibi_slopes.descriptor
alibi_slopes_data = alibi_slopes.data()
alibi_slopes_torch = alibi_slopes.torch_tensor()
# Run reference implementation
ans = ref_single_query_cached_kv_attention(
q.torch_tensor(),
k_cache.torch_tensor(),
v_cache.torch_tensor(),
block_tables.torch_tensor(),
seq_lens.torch_tensor(),
scale,
alibi_slopes_torch,
)
if sync:
sync()
# Create operator descriptor
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePagedAttentionDescriptor(
handle,
ctypes.byref(descriptor),
out.descriptor,
q.descriptor,
k_cache.descriptor,
v_cache.descriptor,
block_tables.descriptor,
seq_lens.descriptor,
alibi_slopes_desc,
scale,
)
)
# Get workspace size and allocate memory
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetPagedAttentionWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, q.device)
# Invalidate descriptors to ensure kernel does not rely on them
q.destroy_desc()
out.destroy_desc()
k_cache.destroy_desc()
v_cache.destroy_desc()
block_tables.destroy_desc()
seq_lens.destroy_desc()
if use_alibi:
alibi_slopes.destroy_desc()
# Define the library call as a lambda for profiling
def lib_paged_attention():
check_error(
LIBINFINIOP.infiniopPagedAttention(
descriptor,
workspace.data(),
workspace_size.value,
out.data(),
q.data(),
k_cache.data(),
v_cache.data(),
block_tables.data(),
seq_lens.data(),
alibi_slopes_data,
None,
)
)
# Execute the custom operator
lib_paged_attention()
if sync:
sync()
# Verify correctness
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: ref_single_query_cached_kv_attention(
q.torch_tensor(), k_cache.torch_tensor(), v_cache.torch_tensor(),
block_tables.torch_tensor(), seq_lens.torch_tensor(),
scale, alibi_slopes_torch),
device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lib_paged_attention, device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
# Clean up resources
check_error(LIBINFINIOP.infiniopDestroyPagedAttentionDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options from command line arguments
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
import ctypes
from ctypes import c_uint64
import torch
from libinfiniop import (
LIBINFINIOP,
InfiniDeviceNames,
InfiniDtype,
InfiniDtypeNames,
TestTensor,
TestWorkspace,
check_error,
debug,
get_args,
get_test_devices,
get_tolerance,
infiniopOperatorDescriptor_t,
profile_operation,
test_operator,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES = [
# num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds
(1, 1, 1, 128, 8, 16, 1),
(1, 4, 4, 128, 8, 16, 4),
(2, 8, 8, 128, 16, 32, 2),
(4, 16, 16, 128, 8, 64, 3),
(8, 64, 64, 128, 8, 16, 5),
(16, 128, 128, 128, 8, 16, 4),
]
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 5
NUM_ITERATIONS = 10
# ==============================================================================
# Helper Classes & Reference Implementation
# ==============================================================================
class SimpleCacheManager:
def __init__(self, num_blocks, block_size):
self.num_blocks = num_blocks
self.block_size = block_size
self.free_blocks = list(range(num_blocks))
self.request_to_blocks = {}
self.request_to_len = {}
def allocate_slots(self, request_id, num_new_tokens):
if request_id not in self.request_to_len:
self.request_to_len[request_id] = 0
self.request_to_blocks[request_id] = []
start_pos = self.request_to_len[request_id]
new_total_len = start_pos + num_new_tokens
needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
added_blocks = needed_blocks - len(self.request_to_blocks[request_id])
for _ in range(added_blocks):
self.request_to_blocks[request_id].append(self.free_blocks.pop(0))
self.request_to_len[request_id] = new_total_len
return self.request_to_blocks[request_id], new_total_len
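# Usage sketch (hypothetical values, not one of the test cases): with
# block_size=16, allocate_slots("req0", 20) pops two free blocks and returns
# (blocks, 20); a follow-up allocate_slots("req0", 10) grows the request to
# 30 tokens, which still fits ceil(30/16) = 2 blocks, so no new block is taken.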
def ref_paged_attention_multi_turn(
query_new, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, scale
):
block_size = k_cache.shape[2]
outputs = torch.zeros_like(query_new)
num_seqs = len(cum_seq_lens_q) - 1
for i in range(num_seqs):
num_new = cum_seq_lens_q[i + 1].item() - cum_seq_lens_q[i].item()
total_len = seq_lens[i].item()
cache_len = seq_lens[i].item() - num_new
table = block_tables[i]
keys_all, values_all = [], []
for j in range(total_len):
b_id = table[j // block_size].item()
off = j % block_size
keys_all.append(k_cache[b_id, :, off, :])
values_all.append(v_cache[b_id, :, off, :])
K = torch.stack(keys_all, dim=0)
V = torch.stack(values_all, dim=0)
Q = query_new[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :]
scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale
mask = torch.full((num_new, total_len), float("-inf"), device=Q.device)
for q_idx in range(num_new):
mask[q_idx, : cache_len + q_idx + 1] = 0.0
scores = scores + mask.unsqueeze(0)
attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, V)
outputs[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :] = out
return outputs
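# The causal mask built above lets new query token q_idx (0-based within this
# round) attend to the first cache_len + q_idx + 1 positions, i.e. the full
# cached prefix plus the new tokens up to and including itself.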
# ==============================================================================
# Test Operator Implementation
# ==============================================================================
def test(
handle,
device,
num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
max_step_len,
num_rounds,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing PagedAttentionPrefill on {InfiniDeviceNames[device]} with "
f"seqs:{num_seqs}, heads:{num_heads}, head_size:{head_size}, "
f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}"
)
# 1. Initialize persistent resources
num_blocks = 8192
manager = SimpleCacheManager(num_blocks, block_size)
scale = head_size**-0.5
k_cache = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
v_cache = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
# Multi-turn testing loop
for r in range(num_rounds):
# Prepare dynamic inputs for this round
query_lens_cpu = torch.randint(
1, max_step_len + 1, (num_seqs,), dtype=torch.int64
)
q_total_tokens = query_lens_cpu.sum().item()
q_packed_tensors = torch.zeros(q_total_tokens, num_heads, head_size)
seq_lens_list = []
all_block_tables = []
cum_seq_lens_q_list = []
cum_q_lens = 0
for i in range(num_seqs):
cum_seq_lens_q_list.append(cum_q_lens)
cur_q_len = query_lens_cpu[i].item()
table, total_len = manager.allocate_slots(i, cur_q_len)
cur_seq_lens = total_len - cur_q_len
seq_lens_list.append(total_len)
all_block_tables.append(table)
# Simulated KV insertion
k_new = torch.randn(cur_q_len, num_kv_heads, head_size)
v_new = torch.randn(cur_q_len, num_kv_heads, head_size)
q_val = torch.randn(cur_q_len, num_heads, head_size)
q_packed_tensors[cum_q_lens : cum_q_lens + cur_q_len] = q_val
cum_q_lens = cum_q_lens + cur_q_len
for t in range(cur_q_len):
logical_pos = cur_seq_lens + t
b_id = table[logical_pos // block_size]
off = logical_pos % block_size
k_cache.torch_tensor()[b_id, :, off, :] = k_new[t]
v_cache.torch_tensor()[b_id, :, off, :] = v_new[t]
cum_seq_lens_q_list.append(cum_q_lens)
k_cache.actual_tensor().copy_(k_cache._torch_tensor)
v_cache.actual_tensor().copy_(v_cache._torch_tensor)
# 2. Wrap tensors for Infiniop
q_new = TestTensor.from_torch(q_packed_tensors, dtype, device)
out = TestTensor.from_torch(q_packed_tensors, dtype, device)
out.actual_tensor().zero_()
seq_lens = TestTensor.from_torch(
torch.tensor(seq_lens_list, dtype=torch.int64), InfiniDtype.I64, device
)
cum_seq_lens_q = TestTensor.from_torch(
torch.tensor(cum_seq_lens_q_list, dtype=torch.int64),
InfiniDtype.I64,
device,
)
max_blocks = max(len(t) for t in all_block_tables)
padded_tables = [t + [0] * (max_blocks - len(t)) for t in all_block_tables]
block_tables = TestTensor.from_torch(
torch.tensor(padded_tables, dtype=torch.int64), InfiniDtype.I64, device
)
# 3. Reference Calculation
def torch_paged_attention_multi_turn():
return ref_paged_attention_multi_turn(
q_new.torch_tensor(),
k_cache.torch_tensor(),
v_cache.torch_tensor(),
block_tables.torch_tensor(),
seq_lens.torch_tensor(),
cum_seq_lens_q.torch_tensor(),
scale,
)
ans = torch_paged_attention_multi_turn()
# 4. Infiniop Operator Execution
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePagedAttentionPrefillDescriptor(
handle,
ctypes.byref(descriptor),
out.descriptor,
q_new.descriptor,
k_cache.descriptor,
v_cache.descriptor,
block_tables.descriptor,
seq_lens.descriptor,
cum_seq_lens_q.descriptor,
None,
scale,
)
)
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetPagedAttentionPrefillWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, device)
def lib_attn():
check_error(
LIBINFINIOP.infiniopPagedAttentionPrefill(
descriptor,
workspace.data(),
workspace_size.value,
out.data(),
q_new.data(),
k_cache.data(),
v_cache.data(),
block_tables.data(),
seq_lens.data(),
cum_seq_lens_q.data(),
None,
None,
)
)
lib_attn()
if sync:
sync()
# 5. Validation
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)
# Profiling
if PROFILE:
profile_operation(
f"Torch_R{r}",
lambda: torch_paged_attention_multi_turn(),
device,
NUM_PRERUN,
NUM_ITERATIONS,
)
profile_operation(
f" Lib_R{r}", lambda: lib_attn(), device, NUM_PRERUN, NUM_ITERATIONS
)
check_error(
LIBINFINIOP.infiniopDestroyPagedAttentionPrefillDescriptor(descriptor)
)
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
args = get_args()
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
TestWorkspace,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping):
"""
Reference implementation for paged_caching operator.
Args:
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
ntok = key.shape[0]
block_size = key_cache_pool.shape[2]
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor.
k_cache_ref = key_cache_pool.clone()
v_cache_ref = value_cache_pool.clone()
for i in range(ntok):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
key_token = key[i]
value_token = value[i]
k_cache_ref[block_idx, :, block_offset, :] = key_token
v_cache_ref[block_idx, :, block_offset, :] = value_token
return k_cache_ref, v_cache_ref
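# Slot decoding used above: a flat slot index s maps to block s // block_size
# at offset s % block_size. For example (illustrative only), with
# block_size=16, slot 35 lands in block 2 at offset 3.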
# ==============================================================================
# Test Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES_ = [
# (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
(1, 128, 8, 128, 16),
(5, 512, 40, 128, 16),
(16, 1024, 8, 64, 32),
(10, 1024, 40, 64, 32),
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 0, "rtol": 1e-5},
InfiniDtype.BF16: {"atol": 0, "rtol": 1e-5},
InfiniDtype.F32: {"atol": 0, "rtol": 1e-5},
}
# Global flags for controlling test behavior
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 100
def test(
handle,
device,
num_seqs, # nreq
max_seq_len,
num_kv_heads, # nkvh
head_size, # dh
block_size,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing PagedCaching on {InfiniDeviceNames[device]} with "
f"num_seqs={num_seqs}, max_seq_len={max_seq_len}, num_kv_heads={num_kv_heads}, "
f"head_size={head_size}, block_size={block_size}, dtype={InfiniDtypeNames[dtype]}"
)
num_blocks = 4096 # A reasonably large cache pool for testing
# Create metadata: variable context lengths for each sequence in the batch
context_lens_torch = torch.randint(
1, max_seq_len + 1, (num_seqs,), dtype=torch.int64
)
ntok = torch.sum(context_lens_torch).item()
# If ntok is 0 (all sequences have length 0), skip the test
if ntok == 0:
print("Skipping test case with ntok=0")
return
# Simulate the scheduler's behavior to create the slot_mapping
slot_mapping_list = []
current_slot = 0
for length in context_lens_torch:
# Find a contiguous chunk of 'length' slots
start_slot = current_slot
slot_mapping_list.extend(range(start_slot, start_slot + length.item()))
current_slot += length.item()
# Ensure we don't exceed the total number of slots in the cache
assert current_slot <= num_blocks * block_size, (
"Not enough blocks in the cache pool for this test case"
)
slot_mapping_torch = torch.tensor(slot_mapping_list, dtype=torch.int64)
# Create input tensors based on the calculated total tokens (ntok)
k = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device)
v = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device)
slot_mapping = TestTensor.from_torch(slot_mapping_torch, InfiniDtype.I64, device)
# The cache pools are the "output" tensors for this operator
k_cache_pool = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
v_cache_pool = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
# Run reference implementation
k_cache_ref, v_cache_ref = ref_paged_caching(
k_cache_pool.torch_tensor(),
v_cache_pool.torch_tensor(),
k.torch_tensor(),
v.torch_tensor(),
slot_mapping.torch_tensor(),
)
if sync:
sync()
# Create operator descriptor
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
handle,
ctypes.byref(descriptor),
k_cache_pool.descriptor,
v_cache_pool.descriptor,
k.descriptor,
v.descriptor,
slot_mapping.descriptor,
)
)
# Get workspace size (likely 0 for this operator, but good practice to include)
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, device)
# Invalidate descriptors to ensure kernel does not rely on them
k.destroy_desc()
v.destroy_desc()
k_cache_pool.destroy_desc()
v_cache_pool.destroy_desc()
slot_mapping.destroy_desc()
# Define the library call as a lambda for profiling
def lib_paged_caching():
check_error(
LIBINFINIOP.infiniopPagedCaching(
descriptor,
workspace.data(),
workspace_size.value,
k_cache_pool.data(),
v_cache_pool.data(),
k.data(),
v.data(),
slot_mapping.data(),
None,
)
)
# Execute the custom operator
lib_paged_caching()
if sync:
sync()
# Verify correctness
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
print("Verifying K cache...")
debug(k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
print("Verifying V cache...")
debug(v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
assert torch.allclose(
k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol
)
assert torch.allclose(
v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol
)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: ref_paged_caching(
k_cache_pool.torch_tensor(), v_cache_pool.torch_tensor(),
k.torch_tensor(), v.torch_tensor(),
slot_mapping.torch_tensor()),
device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lib_paged_caching, device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
# Clean up resources
check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options from command line arguments
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES = [
# num_seqs, max_step_len, num_kv_heads, head_size, block_size, num_rounds
(1, 16, 1, 128, 8, 5),
(2, 64, 8, 128, 16, 2),
(8, 128, 32, 128, 16, 3),
(5, 512, 40, 128, 16, 3),
(16, 64, 8, 128, 32, 1),
(10, 256, 40, 128, 32, 3),
]
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
_TOLERANCE_MAP = {
InfiniDtype.F32: {"atol": 1e-8, "rtol": 1e-8},
InfiniDtype.F16: {"atol": 1e-8, "rtol": 1e-8},
InfiniDtype.BF16: {"atol": 1e-8, "rtol": 1e-8},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 5
NUM_ITERATIONS = 10
# ==============================================================================
# Helper Classes & Reference Implementation
# ==============================================================================
class SimpleCacheManager:
def __init__(self, num_blocks, block_size):
self.num_blocks = num_blocks
self.block_size = block_size
self.free_blocks = list(range(num_blocks))
self.request_to_blocks = {}
self.request_to_len = {}
def allocate_slots(self, request_id, num_new_tokens):
if request_id not in self.request_to_len:
self.request_to_len[request_id] = 0
self.request_to_blocks[request_id] = []
start_pos = self.request_to_len[request_id]
new_total_len = start_pos + num_new_tokens
needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
added_blocks = needed_blocks - len(self.request_to_blocks[request_id])
for _ in range(added_blocks):
self.request_to_blocks[request_id].append(self.free_blocks.pop(0))
slots = []
for i in range(start_pos, new_total_len):
block_idx_in_seq = i // self.block_size
block_offset = i % self.block_size
physical_block_id = self.request_to_blocks[request_id][block_idx_in_seq]
slots.append(physical_block_id * self.block_size + block_offset)
self.request_to_len[request_id] = new_total_len
return torch.tensor(slots, dtype=torch.int32)
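# Slot encoding used above: token position i of a request maps to
# physical_block_id * block_size + (i % block_size). For example (illustrative
# only), with block_size=8 and blocks [7, 2] assigned, token 9 sits in the
# second logical block (9 // 8 == 1), physical block 2, giving slot
# 2 * 8 + 9 % 8 == 17.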
def ref_paged_caching(k_pool, v_pool, k_new, v_new, slots, block_size):
"""Reference implementation for incremental caching."""
for i in range(k_new.shape[0]):
slot = slots[i].item()
b_id = slot // block_size
off = slot % block_size
k_pool[b_id, :, off, :] = k_new[i]
v_pool[b_id, :, off, :] = v_new[i]
return k_pool, v_pool
# ==============================================================================
# Test Operator Implementation
# ==============================================================================
def test(
handle,
device,
num_seqs,
max_step_len,
num_kv_heads,
head_size,
block_size,
num_rounds,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing PagedCaching on {InfiniDeviceNames[device]} with "
f"seqs:{num_seqs}, max_step_len:{max_step_len}, num_kv_heads:{num_kv_heads}, head_size:{head_size}, "
f"block_size:{block_size}, rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}"
)
# 1. Initialize Global Cache Pool
num_blocks = 8192
manager = SimpleCacheManager(num_blocks, block_size)
k_cache_pool = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
v_cache_pool = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
# Reference pools (CPU/Torch)
k_pool_ref = k_cache_pool.torch_tensor().clone()
v_pool_ref = v_cache_pool.torch_tensor().clone()
for r in range(num_rounds):
# Prepare incremental data for this round
round_ntok_list = torch.randint(
1, max_step_len + 1, (num_seqs,), dtype=torch.int32
)
all_slots, all_k, all_v = [], [], []
for i in range(num_seqs):
n_new = round_ntok_list[i].item()
all_slots.append(manager.allocate_slots(i, n_new))
all_k.append(torch.randn(n_new, num_kv_heads, head_size))
all_v.append(torch.randn(n_new, num_kv_heads, head_size))
k_in_torch = torch.cat(all_k, dim=0)
v_in_torch = torch.cat(all_v, dim=0)
slots_torch = torch.cat(all_slots, dim=0)
k_in = TestTensor.from_torch(k_in_torch, dtype, device)
v_in = TestTensor.from_torch(v_in_torch, dtype, device)
slot_mapping = TestTensor.from_torch(slots_torch, InfiniDtype.I64, device)
# 2. Reference Calculation
def torch_caching():
nonlocal k_pool_ref, v_pool_ref
return ref_paged_caching(
k_pool_ref,
v_pool_ref,
k_in.torch_tensor(),
v_in.torch_tensor(),
slots_torch,
block_size,
)
torch_caching()
# 3. Infiniop Operator Execution
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
handle,
ctypes.byref(descriptor),
k_cache_pool.descriptor,
v_cache_pool.descriptor,
k_in.descriptor,
v_in.descriptor,
slot_mapping.descriptor,
)
)
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, device)
def lib_caching():
check_error(
LIBINFINIOP.infiniopPagedCaching(
descriptor,
workspace.data(),
workspace_size.value,
k_cache_pool.data(),
v_cache_pool.data(),
k_in.data(),
v_in.data(),
slot_mapping.data(),
None,
)
)
lib_caching()
if sync:
sync()
# 4. Validation
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
# Check a small slice of the updated cache
debug(k_cache_pool.actual_tensor(), k_pool_ref, atol=atol, rtol=rtol)
assert torch.allclose(
k_cache_pool.actual_tensor(), k_pool_ref, atol=atol, rtol=rtol
)
assert torch.allclose(
v_cache_pool.actual_tensor(), v_pool_ref, atol=atol, rtol=rtol
)
# 5. Profiling
if PROFILE:
profile_operation(
f"Torch_R{r}",
lambda: torch_caching(),
device,
NUM_PRERUN,
NUM_ITERATIONS,
)
profile_operation(
f" Lib_R{r}", lambda: lib_caching(), device, NUM_PRERUN, NUM_ITERATIONS
)
check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor))
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
args = get_args()
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# x_shape, symmetric
((8, 8), True),
((128, 512), True),
((128, 128), True),
((256, 1024), False),
((256, 2048), True),
((1024, 2048), False),
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 5e-2},
InfiniDtype.BF16: {"atol": 1e-3, "rtol": 5e-2},
InfiniDtype.F32: {"atol": 3e-5, "rtol": 5e-3},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def per_token_quant_int8_torch(x, symmetric):
if symmetric:
x = x.float()
absmax = x.abs().max(dim=-1).values
absmax = absmax.clamp_min(1e-10).unsqueeze(-1)
scale_x = absmax / 127
x_q = x.mul(127 / absmax)
x_q = torch.round(x_q).to(torch.int8)
return x_q, scale_x, None
else:
w = x.float()
w_min = w.min(dim=-1, keepdim=True)[0]
w_max = w.max(dim=-1, keepdim=True)[0]
w_scale = (w_max - w_min) / 255.0
w_scale = torch.clamp(w_scale, min=1e-8)
w_zero = -w_min / w_scale - 128.0
w_q = torch.round(w / w_scale + w_zero)
w_q = torch.clamp(w_q, -128, 127)
w_packed = w_q.to(torch.int8)
return w_packed, w_scale, w_zero
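# Dequantization identities implied by the code above (a sketch, assuming no
# values were clamped):
#   symmetric:  x ≈ x_q.float() * scale_x,           scale_x = absmax / 127
#   asymmetric: w ≈ (w_q.float() - w_zero) * w_scale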
def test(
handle,
device,
x_shape,
symmetric,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing Per Channel Quant Int8 on {InfiniDeviceNames[device]} with x_shape:{x_shape}, symmetric:{symmetric} , dtype:{InfiniDtypeNames[dtype]}"
)
M, K = x_shape
x = TestTensor(x_shape, None, dtype, device)
x_p, x_s, x_z = per_token_quant_int8_torch(x.torch_tensor(), symmetric)
x_packed = TestTensor(x_shape, None, InfiniDtype.I8, device, mode="zeros")
x_scale = TestTensor((M, 1), None, InfiniDtype.F32, device)
if symmetric:
x_zero = None
else:
x_zero = TestTensor((M, 1), None, InfiniDtype.F32, device)
if sync is not None:
sync()
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePerChannelQuantI8Descriptor(
handle,
ctypes.byref(descriptor),
x_packed.descriptor,
x_scale.descriptor,
None if symmetric else x_zero.descriptor,
x.descriptor,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_packed.destroy_desc()
x_scale.destroy_desc()
if not symmetric:
    x_zero.destroy_desc()
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetPerChannelQuantI8WorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, x.device)
def lib_per_channel_quant_int8():
check_error(
LIBINFINIOP.infiniopPerChannelQuantI8(
descriptor,
workspace.data(),
workspace_size.value,
x_packed.data(),
x_scale.data(),
None if symmetric else x_zero.data(),
x.data(),
None,
)
)
lib_per_channel_quant_int8()
if sync is not None:
sync()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(x_packed.actual_tensor(), x_p, atol=atol, rtol=rtol)
debug(x_scale.actual_tensor(), x_s, atol=atol, rtol=rtol)
if not symmetric:
    debug(x_zero.actual_tensor(), x_z, atol=atol, rtol=rtol)
if symmetric:
assert (torch.allclose(x_packed.actual_tensor(), x_p, atol=2, rtol=2) and
torch.allclose(x_scale.actual_tensor(), x_s, atol=atol, rtol=rtol))
else:
assert (torch.allclose(x_packed.actual_tensor(), x_p, atol=2, rtol=2) and
torch.allclose(x_scale.actual_tensor(), x_s, atol=atol, rtol=rtol) and
torch.allclose(x_zero.actual_tensor(), x_z, atol=atol, rtol=rtol))
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: per_token_quant_int8_torch(x.torch_tensor(), symmetric), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_per_channel_quant_int8(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(LIBINFINIOP.infiniopDestroyPerChannelQuantI8Descriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# x_shape, w_shape, y_shape
((128, 512), (512, 1024), (128, 1024)),
((256, 1024), (1024, 2048), (256, 2048)),
((1024, 2048), (2048, 1024), (1024, 1024)),
]
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE = auto()
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [
Inplace.INPLACE,
]
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 3e-1, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 3e-1, "rtol": 1e-2},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
if bias is not None:
o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + bias
else:
o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
return o.to(out_dtype)
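# Why the scales factor out in torch_scaled_mm (row scale per token, column
# scale per output channel):
#   (diag(scale_a) @ A) @ (B @ diag(scale_b))
#     == scale_a.view(-1, 1) * (A @ B) * scale_b.view(1, -1)
# so one float32 matmul on the int8 values plus an outer-product rescale
# reproduces the dequantized GEMM.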
def test(
handle,
device,
x_shape,
w_shape,
y_shape,
inplace=Inplace.OUT_OF_PLACE,
dtype=InfiniDtype.BF16,
sync=None,
):
print(
f"Testing Linear on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}"
)
M, K = x_shape
N = w_shape[1]
x_packed = to_int8(torch.randn((M, K), device="cuda") * 5)
weights = to_int8(torch.randn((N, K), device="cuda").t() * 5)
x_scale = torch.randn((M,), device="cuda", dtype=torch.float32)
weights_scale = torch.randn((N,), device="cuda", dtype=torch.float32)
bias = torch.randn((N,), device="cuda", dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16) * 10
ans = torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias)
x_packed = TestTensor(
(M, K), x_packed.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=x_packed
)
x_scale = TestTensor(
(M,), x_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=x_scale
)
weights = TestTensor(
(K, N), weights.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights
)
weights_scale = TestTensor(
(N,), weights_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale
)
y = TestTensor(y_shape, None, dtype, device)
bias = TestTensor((N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias)
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateI8GemmDescriptor(
handle,
ctypes.byref(descriptor),
y.descriptor,
bias.descriptor,
x_packed.descriptor,
x_scale.descriptor,
weights.descriptor,
weights_scale.descriptor,
)
)
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetI8GemmWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, x_packed.device)
def lib_linear():
check_error(
LIBINFINIOP.infiniopI8Gemm(
descriptor,
workspace.data(),
workspace_size.value,
y.data(),
bias.data(),
x_packed.data(),
x_scale.data(),
weights.data(),
weights_scale.data(),
None,
)
)
lib_linear()
if sync is not None:
sync()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(y.actual_tensor(), ans, atol=atol, rtol=rtol)
assert torch.allclose(y.actual_tensor(), ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(LIBINFINIOP.infiniopDestroyI8GemmDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# Format: (input_shape, output_shape)
# Referencing vLLM kernel Silu_and_Mul interface:
# input_shape is [..., 2*d], output_shape is [..., d]
_TEST_CASES = [
# input_shape, output_shape
((2, 8), (2, 4)),
((1024, 1024), (1024, 512)),
((16, 8192), (16, 4096)),
((2, 128, 2048), (2, 128, 1024)),
((8, 1, 4096), (8, 1, 2048)),
((2, 4, 16, 256), (2, 4, 16, 128)),
]
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6},
InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 100
# PyTorch reference: silu(gate) * up where [gate, up] = split(input)
def silu_and_mul_torch(out, input_tensor):
"""
Computes the SwiGLU activation function: SiLU(gate) * up.
"""
# Split the last dimension into two halves:
# the first half is 'gate', the second is 'up'
d = input_tensor.shape[-1] // 2
gate = input_tensor[..., :d]
up = input_tensor[..., d:]
# Apply SiLU to the gate and multiply by the up projection
torch.mul(torch.nn.functional.silu(gate), up, out=out)
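# Minimal usage sketch (hypothetical shapes): for x of shape (4, 8), the last
# dimension splits into gate = x[..., :4] and up = x[..., 4:], and
#   out = torch.nn.functional.silu(gate) * up    # SiLU(x) = x * sigmoid(x)
# which is what the kernel under test computes into a (4, 4) output.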
# ==============================================================================
# Test Logic
# ==============================================================================
def test(
handle,
device,
input_shape,
output_shape,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing SiluAndMul on {InfiniDeviceNames[device]} with "
f"input_shape:{input_shape} output_shape:{output_shape} dtype:{InfiniDtypeNames[dtype]}"
)
a = TestTensor(input_shape, None, dtype, device)
c = TestTensor(output_shape, None, dtype, device, mode="zeros")
ans = TestTensor(output_shape, None, dtype, device, mode="zeros")
# Only support contiguous Tensor
if not (
a.torch_tensor().is_contiguous()
and c.torch_tensor().is_contiguous()
and ans.torch_tensor().is_contiguous()
):
raise ValueError("This operator only supports contiguous memory layout.")
# PyTorch answer reference
def torch_silu_and_mul_reference():
silu_and_mul_torch(ans.torch_tensor(), a.torch_tensor())
torch_silu_and_mul_reference()
if sync is not None:
sync()
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateSiluAndMulDescriptor(
handle,
ctypes.byref(descriptor),
c.descriptor,
a.descriptor,
)
)
for tensor in [a, c]:
tensor.destroy_desc()
# Workspace
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetSiluAndMulWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, device)
def lib_op():
check_error(
LIBINFINIOP.infiniopSiluAndMul(
descriptor,
workspace.data(),
workspace_size.value,
c.data(),
a.data(),
None,
)
)
lib_op()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol)
assert torch.allclose(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
profile_operation(
"PyTorch",
lambda: torch_silu_and_mul_reference(),
device,
NUM_PRERUN,
NUM_ITERATIONS,
)
profile_operation(
" lib", lambda: lib_op(), device, NUM_PRERUN, NUM_ITERATIONS
)
check_error(LIBINFINIOP.infiniopDestroySiluAndMulDescriptor(descriptor))
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
args = get_args()
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mSiluAndMul Test passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# x_shape = [M,K], w_shape = [N, K], sym, y_shape = [M, N]
((100, 3584), (10752, 3584), True, (100, 10752)),
((1000, 3584), (10752, 3584), True, (1000, 10752)),
((1, 3584), (10752, 3584), True, (1, 10752)),
((2000, 3584), (10752, 3584), True, (2000, 10752)),
]
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE = auto()
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [
Inplace.INPLACE,
]
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 3e-1, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 3e-1, "rtol": 1e-2},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def mm(x, w, bias, out_dtype):
    return (torch.matmul(x, w) + bias).to(out_dtype)
def scaled_mm(x, w_p, w_s, bias, out_dtype):
return (
torch.matmul(x.to(torch.float32), w_p.to(torch.float32)) * w_s.view(1, -1)
+ bias
).to(out_dtype)
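# scaled_mm above dequantizes on the fly: the fp32 product with the packed int8
# weights is rescaled per output channel by w_s before the bias is added.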
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
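# Values are clamped to [-128, 127] before rounding above, so the cast to
# torch.int8 cannot overflow.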
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
if bias is not None:
o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + bias
else:
o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
return o.to(out_dtype)
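# In torch_scaled_mm above, each accumulator o[m, n] is rescaled by
# scale_a[m] * scale_b[n]: per-token activation scales times per-channel weight
# scales, matching the dequantization identity x ≈ x_q * s_a and w ≈ w_q * s_b.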
def per_token_quant_int8_torch(x):
x = x.float()
absmax = x.abs().max(dim=-1).values
absmax = absmax.clamp_min(1e-10).unsqueeze(-1)
scale_x = absmax / 127
x_q = x.mul(127 / absmax)
x_q = torch.round(x_q).to(torch.int8)
return x_q, scale_x
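# Round-trip property of the quantizer above (a sketch): x ≈ x_q.float() * scale_x,
# with per-element error bounded by scale_x / 2 due to rounding.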
def test(
handle,
device,
x_shape,
w_shape,
symmetric,
y_shape,
inplace=Inplace.OUT_OF_PLACE,
dtype=InfiniDtype.BF16,
sync=None,
):
print(
f"Testing Linear on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, symmetric:{symmetric}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}"
)
M, K = x_shape
N = w_shape[0]
x = TestTensor(x_shape, None, dtype, device)
x_packed = TestTensor(x_shape, None, InfiniDtype.I8, device, mode="zeros")
x_scale = TestTensor((M, 1), None, InfiniDtype.F32, device)
dev = x.torch_tensor().device
weights_packed = to_int8(torch.randn(w_shape, device=dev).t() * 5)
weights_scale = torch.randn((N, 1), device=dev, dtype=torch.float32)
bias = (
torch.randn(
(N,),
device=dev,
dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
)
* 10
)
w_packed = TestTensor(
(K, N),
weights_packed.stride(),
InfiniDtype.I8,
device,
mode="manual",
set_tensor=weights_packed,
)
w_scale = TestTensor(
(N, 1),
weights_scale.stride(),
InfiniDtype.F32,
device,
mode="manual",
set_tensor=weights_scale,
)
weights = w_packed.torch_tensor() * w_scale.torch_tensor().view(1, -1)
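    # Dequantized fp32 view of the packed weights; it only feeds the PyTorch
    # baselines (w_mm) below, never the library path.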
y = TestTensor(y_shape, None, dtype, device)
bias = TestTensor(
(N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias
)
x_mm = x.torch_tensor().to(
torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16
)
w_mm = weights.to(torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16)
quant_descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePerChannelQuantI8Descriptor(
handle,
ctypes.byref(quant_descriptor),
x_packed.descriptor,
x_scale.descriptor,
None,
x.descriptor,
)
)
quant_workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetPerChannelQuantI8WorkspaceSize(
quant_descriptor, ctypes.byref(quant_workspace_size)
)
)
quant_workspace = TestWorkspace(quant_workspace_size.value, x.device)
def lib_per_channel_quant_int8():
check_error(
LIBINFINIOP.infiniopPerChannelQuantI8(
quant_descriptor,
quant_workspace.data(),
quant_workspace_size.value,
x_packed.data(),
x_scale.data(),
None,
x.data(),
None,
)
)
scaled_mm_descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateI8GemmDescriptor(
handle,
ctypes.byref(scaled_mm_descriptor),
y.descriptor,
bias.descriptor,
x_packed.descriptor,
x_scale.descriptor,
w_packed.descriptor,
w_scale.descriptor,
)
)
scaled_mm_workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetI8GemmWorkspaceSize(
scaled_mm_descriptor, ctypes.byref(scaled_mm_workspace_size)
)
)
scaled_mm_workspace = TestWorkspace(scaled_mm_workspace_size.value, x_packed.device)
def lib_linear():
check_error(
LIBINFINIOP.infiniopI8Gemm(
scaled_mm_descriptor,
scaled_mm_workspace.data(),
scaled_mm_workspace_size.value,
y.data(),
bias.data(),
x_packed.data(),
x_scale.data(),
w_packed.data(),
w_scale.data(),
None,
)
)
def lib_w8a8int8_linearFunction():
lib_per_channel_quant_int8()
lib_linear()
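    # End-to-end W8A8 path under test: per-row dynamic int8 quantization of the
    # activations, followed by the int8 GEMM that consumes x_scale, w_scale, and bias.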
def lib_torch_mm():
mm(
x_mm,
w_mm,
bias.torch_tensor(),
out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
)
x_p, x_s = per_token_quant_int8_torch(x.torch_tensor())
lib_w8a8int8_linearFunction()
scaled_mm_torch = torch_scaled_mm(
x_p,
w_packed.torch_tensor(),
x_s,
w_scale.torch_tensor(),
torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
bias=bias.torch_tensor(),
)
mm_torch = scaled_mm(
x.torch_tensor(),
w_packed.torch_tensor(),
w_scale.torch_tensor(),
bias.torch_tensor(),
out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16,
)
if sync is not None:
sync()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(y.actual_tensor(), mm_torch, atol=atol, rtol=rtol)
    # The quantization test does not normalize its input data, which leads to
    # large errors; the accuracy assertion is temporarily disabled.
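    # Local CUDA-event-based timer; this file defines its own rather than
    # importing profile_operation from libinfiniop, and it assumes a CUDA device.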
def profile_operation(name, func, device, num_prerun, num_iterations):
# Warm up
for _ in range(num_prerun):
func()
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(num_iterations):
func()
end.record()
torch.cuda.synchronize()
elapsed = start.elapsed_time(end)
print(
f"{name} took {elapsed / num_iterations:.6f} ms over {num_iterations} iterations"
)
# Profiling workflow
if PROFILE:
profile_operation(
"PyTorch mm ",
lambda: lib_torch_mm(),
device,
NUM_PRERUN,
NUM_ITERATIONS,
)
profile_operation(
"lib total ",
lambda: lib_w8a8int8_linearFunction(),
device,
NUM_PRERUN,
NUM_ITERATIONS,
)
profile_operation(
"lib quant ",
lambda: lib_per_channel_quant_int8(),
device,
NUM_PRERUN,
NUM_ITERATIONS,
)
profile_operation(
"lib scaled mm ",
lambda: lib_linear(),
device,
NUM_PRERUN,
NUM_ITERATIONS,
)
check_error(LIBINFINIOP.infiniopDestroyI8GemmDescriptor(scaled_mm_descriptor))
check_error(
LIBINFINIOP.infiniopDestroyPerChannelQuantI8Descriptor(quant_descriptor)
)
if __name__ == "__main__":
args = get_args()
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
Subproject commit 55f93686c01528224f448c19128836e7df245f72
......@@ -11,6 +11,7 @@ set_encodings("utf-8")
add_includedirs("include")
add_includedirs("third_party/spdlog/include")
add_includedirs("third_party/nlohmann_json/single_include/")
if is_mode("debug") then
add_defines("DEBUG_MODE")
......@@ -19,7 +20,7 @@ end
if is_plat("windows") then
set_runtimes("MD")
add_ldflags("/utf-8", {force = true})
add_cxflags("/utf-8", {force = true})
add_cxxflags("/utf-8", {force = true})
end
-- CPU
......@@ -66,6 +67,16 @@ if has_config("cudnn") then
add_defines("ENABLE_CUDNN_API")
end
option("cutlass")
set_default(false)
set_showmenu(true)
set_description("Whether to compile cutlass for Nvidia GPU")
option_end()
if has_config("cutlass") then
add_defines("ENABLE_CUTLASS_API")
end
option("cuda_arch")
set_showmenu(true)
set_description("Set CUDA GPU architecture (e.g. sm_90)")
......@@ -104,11 +115,29 @@ option("iluvatar-gpu")
set_description("Whether to compile implementations for Iluvatar GPU")
option_end()
option("ivcore-20")
set_default(false)
set_showmenu(true)
set_description("Use ivcore20")
option_end()
if has_config("iluvatar-gpu") then
add_defines("ENABLE_ILUVATAR_API")
includes("xmake/iluvatar.lua")
end
-- ali
option("ali-ppu")
set_default(false)
set_showmenu(true)
set_description("Whether to compile implementations for Ali PPU")
option_end()
if has_config("ali-ppu") then
add_defines("ENABLE_ALI_API")
includes("xmake/ali.lua")
end
-- qy
option("qy-gpu")
set_default(false)
......@@ -189,6 +218,18 @@ if has_config("ninetoothed") then
add_defines("ENABLE_NINETOOTHED")
end
-- cuda graph
option("graph")
set_default(false)
set_showmenu(true)
set_description("Whether to use device graph instantiating feature, such as cuda graph for nvidia")
option_end()
if has_config("graph") then
add_defines("USE_INFINIRT_GRAPH")
end
-- InfiniCCL
option("ccl")
set_default(false)
......@@ -208,14 +249,15 @@ target("infini-utils")
set_warnings("all", "error")
if is_plat("windows") then
add_cxflags("/wd4068")
add_cxxflags("/wd4068")
if has_config("omp") then
add_cxflags("/openmp")
add_cxxflags("/openmp")
end
else
add_cxflags("-fPIC", "-Wno-unknown-pragmas")
add_cxxflags("-fPIC", "-Wno-unknown-pragmas")
if has_config("omp") then
add_cxflags("-fopenmp")
add_cxxflags("-fopenmp")
add_ldflags("-fopenmp", {force = true})
end
end
......@@ -247,6 +289,9 @@ target("infinirt")
if has_config("iluvatar-gpu") then
add_deps("infinirt-iluvatar")
end
if has_config("ali-ppu") then
add_deps("infinirt-ali")
end
if has_config("qy-gpu") then
add_deps("infinirt-qy")
add_files("build/.objs/infinirt-qy/rules/qy.cuda/src/infinirt/cuda/*.cu.o", {public = true})
......@@ -258,6 +303,10 @@ target("infinirt")
add_deps("infinirt-hygon")
end
set_languages("cxx17")
if not is_plat("windows") then
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
end
set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini"))
add_files("src/infinirt/*.cc")
add_installfiles("include/infinirt.h", {prefixdir = "include"})
......@@ -276,9 +325,13 @@ target("infiniop")
if has_config("iluvatar-gpu") then
add_deps("infiniop-iluvatar")
end
if has_config("ali-ppu") then
add_deps("infiniop-ali")
end
if has_config("qy-gpu") then
add_deps("infiniop-qy")
add_files("build/.objs/infiniop-qy/rules/qy.cuda/src/infiniop/ops/*/nvidia/*.cu.o", {public = true})
add_files("build/.objs/infiniop-qy/rules/qy.cuda/src/infiniop/ops/*/*/nvidia/*.cu.o", {public = true})
add_files("build/.objs/infiniop-qy/rules/qy.cuda/src/infiniop/devices/nvidia/*.cu.o", {public = true})
end
......@@ -302,7 +355,7 @@ target("infiniop")
end
set_languages("cxx17")
add_files("src/infiniop/devices/handle.cc")
add_files("src/infiniop/ops/*/operator.cc")
add_files("src/infiniop/ops/*/operator.cc", "src/infiniop/ops/*/*/operator.cc")
add_files("src/infiniop/*.cc")
set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini"))
......@@ -331,6 +384,9 @@ target("infiniccl")
if has_config("iluvatar-gpu") then
add_deps("infiniccl-iluvatar")
end
if has_config("ali-ppu") then
add_deps("infiniccl-ali")
end
if has_config("qy-gpu") then
add_deps("infiniccl-qy")
add_files("build/.objs/infiniccl-qy/rules/qy.cuda/src/infiniccl/cuda/*.cu.o", {public = true})
......@@ -380,6 +436,7 @@ target("infinicore_cpp_api")
add_files("src/infinicore/context/*.cc")
add_files("src/infinicore/context/*/*.cc")
add_files("src/infinicore/tensor/*.cc")
add_files("src/infinicore/graph/*.cc")
add_files("src/infinicore/nn/*.cc")
add_files("src/infinicore/ops/*/*.cc")
add_files("src/utils/*.cc")
......@@ -408,6 +465,8 @@ target("_infinicore")
add_packages("pybind11")
set_languages("cxx17")
add_deps("infinicore_cpp_api")
set_kind("shared")
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
add_includedirs(INFINI_ROOT.."/include", { public = true })
......@@ -415,14 +474,7 @@ target("_infinicore")
add_linkdirs(INFINI_ROOT.."/lib")
add_links("infiniop", "infinirt", "infiniccl")
add_files("src/infinicore/*.cc")
add_files("src/infinicore/context/*.cc")
add_files("src/infinicore/context/*/*.cc")
add_files("src/infinicore/tensor/*.cc")
add_files("src/infinicore/nn/*.cc")
add_files("src/infinicore/ops/*/*.cc")
add_files("src/infinicore/pybind11/**.cc")
add_files("src/utils/*.cc")
set_installdir("python/infinicore")
target_end()
......
local CUDNN_ROOT = os.getenv("CUDNN_ROOT") or os.getenv("CUDNN_HOME") or os.getenv("CUDNN_PATH")
if CUDNN_ROOT ~= nil then
add_includedirs(CUDNN_ROOT .. "/include")
end
local CUTLASS_ROOT = os.getenv("CUTLASS_ROOT") or os.getenv("CUTLASS_HOME") or os.getenv("CUTLASS_PATH")
if CUTLASS_ROOT ~= nil then
add_includedirs(CUTLASS_ROOT)
end
target("infiniop-ali")
set_kind("static")
add_deps("infini-utils")
on_install(function (target) end)
set_policy("build.cuda.devlink", true)
set_toolchains("cuda")
add_links("cudart", "cublas")
if has_config("cudnn") then
add_links("cudnn")
end
on_load(function (target)
import("lib.detect.find_tool")
local nvcc = find_tool("nvcc")
if nvcc ~= nil then
if is_plat("windows") then
nvcc_path = os.iorun("where nvcc"):match("(.-)\r?\n")
else
nvcc_path = nvcc.program
end
target:add("linkdirs", path.directory(path.directory(nvcc_path)) .. "/lib64/stubs")
target:add("links", "cuda")
end
end)
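    -- The stubs directory ships a driver-API libcuda stub with the CUDA toolkit,
    -- so linking succeeds on build machines without an installed GPU driver.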
if is_plat("windows") then
add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler")
add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX")
add_cxxflags("/FS")
if CUDNN_ROOT ~= nil then
add_linkdirs(CUDNN_ROOT .. "\\lib\\x64")
end
else
add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror")
add_cuflags("-Xcompiler=-fPIC")
add_cuflags("--extended-lambda")
add_culdflags("-Xcompiler=-fPIC")
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
add_cflags("-fPIC")
add_cuflags("--expt-relaxed-constexpr")
if CUDNN_ROOT ~= nil then
add_linkdirs(CUDNN_ROOT .. "/lib")
end
end
add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations", "-Xcompiler=-Wno-error=unused-function")
local arch_opt = get_config("cuda_arch")
if arch_opt and type(arch_opt) == "string" then
for _, arch in ipairs(arch_opt:split(",")) do
arch = arch:trim()
local compute = arch:gsub("sm_", "compute_")
add_cuflags("-gencode=arch=" .. compute .. ",code=" .. arch)
end
else
add_cugencodes("native")
end
set_languages("cxx17")
add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu")
if has_config("ninetoothed") then
add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp")
end
target_end()
target("infinirt-ali")
set_kind("static")
add_deps("infini-utils")
on_install(function (target) end)
set_policy("build.cuda.devlink", true)
set_toolchains("cuda")
add_links("cudart")
if is_plat("windows") then
add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler")
add_cxxflags("/FS")
else
add_cuflags("-Xcompiler=-fPIC", "-Xcompiler=-shared")
add_culdflags("-Xcompiler=-fPIC", "-Xcompiler=-shared")
add_cxflags("-fPIC", "-shared")
add_cxxflags("-fPIC", "-shared")
add_shflags("-fPIC")
end
set_languages("cxx17")
add_files("../src/infinirt/cuda/*.cu")
target_end()
target("infiniccl-ali")
set_kind("static")
add_deps("infinirt")
on_install(function (target) end)
if has_config("ccl") then
set_policy("build.cuda.devlink", true)
set_toolchains("cuda")
add_links("cudart")
if not is_plat("windows") then
add_cuflags("-Xcompiler=-fPIC")
add_culdflags("-Xcompiler=-fPIC")
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
local nccl_root = os.getenv("NCCL_ROOT")
if nccl_root then
add_includedirs(nccl_root .. "/include")
add_links(nccl_root .. "/lib/libnccl.so")
else
add_links("nccl") -- Fall back to default nccl linking
end
add_files("../src/infiniccl/cuda/*.cu")
else
print("[Warning] NCCL is not supported on Windows")
end
end
set_languages("cxx17")
target_end()
......@@ -44,6 +44,7 @@ target("infiniop-ascend")
on_install(function (target) end)
add_cxflags("-lstdc++ -fPIC")
add_cxxflags("-lstdc++ -fPIC")
set_warnings("all", "error")
set_languages("cxx17")
......@@ -62,6 +63,7 @@ target("infinirt-ascend")
-- Add files
add_files("$(projectdir)/src/infinirt/ascend/*.cc")
add_cxflags("-lstdc++ -Wall -Werror -fPIC")
add_cxxflags("-lstdc++ -Wall -Werror -fPIC")
target_end()
target("infiniccl-ascend")
......@@ -76,5 +78,6 @@ target("infiniccl-ascend")
add_links("libhccl.so")
add_files("../src/infiniccl/ascend/*.cc")
add_cxflags("-lstdc++ -fPIC")
add_cxxflags("-lstdc++ -fPIC")
end
target_end()
......@@ -41,6 +41,7 @@ target("infiniop-cambricon")
on_install(function (target) end)
add_cxflags("-lstdc++ -fPIC")
add_cxxflags("-lstdc++ -fPIC")
set_warnings("all", "error")
set_languages("cxx17")
......@@ -59,6 +60,7 @@ target("infinirt-cambricon")
-- Add include dirs
add_files("../src/infinirt/bang/*.cc")
add_cxflags("-lstdc++ -Wall -Werror -fPIC")
add_cxxflags("-lstdc++ -Wall -Werror -fPIC")
target_end()
target("infiniccl-cambricon")
......@@ -89,6 +91,7 @@ target("infiniccl-cambricon")
add_files("../src/infiniccl/cambricon/*.cc")
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
add_ldflags("-fPIC")
else
print("[Warning] CNCL is currently only supported on Linux")
......
......@@ -6,14 +6,15 @@ target("infiniop-cpu")
set_warnings("all", "error")
if is_plat("windows") then
add_cxflags("/wd4068")
add_cxxflags("/wd4068")
if has_config("omp") then
add_cxflags("/openmp")
add_cxxflags("/openmp")
end
else
add_cxflags("-fPIC", "-Wno-unknown-pragmas")
add_cxxflags("-fPIC", "-Wno-unknown-pragmas")
if has_config("omp") then
add_cxflags("-fopenmp")
add_cxxflags("-fopenmp")
add_ldflags("-fopenmp")
end
end
......@@ -32,6 +33,7 @@ target("infinirt-cpu")
if not is_plat("windows") then
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
end
set_languages("cxx17")
......
......@@ -60,6 +60,7 @@ target("infiniop-hygon")
add_cuflags("-fPIC", "-std=c++17", {force = true})
add_culdflags("-fPIC")
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
    -- Add Hygon DCU-specific compile flags
    -- Detect the actual GPU architecture; default to gfx906 if unspecified
......@@ -71,7 +72,7 @@ target("infiniop-hygon")
add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu")
if has_config("ninetoothed") then
add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}})
add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp", {cxxflags = {"-Wno-return-type"}})
end
target_end()
......@@ -100,6 +101,7 @@ target("infinirt-hygon")
add_cuflags("-fPIC", "-std=c++17", {force = true})
add_culdflags("-fPIC")
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
    -- Add Hygon DCU-specific compile flags
    -- Detect the actual GPU architecture; default to gfx906 if unspecified
......@@ -135,6 +137,7 @@ target("infiniccl-hygon")
add_cuflags("-fPIC", "-std=c++17", {force = true})
add_culdflags("-fPIC")
add_cxflags("-fPIC")
add_cxxflags("-fPIC")
    -- Add Hygon DCU-specific compile flags
    -- Detect the actual GPU architecture; default to gfx906 if unspecified
......