Commit 054763bc authored by PanZezhong

issue/115 Rename matmul to gemm

parent 2ede6b81
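The rename only swaps the `Matmul` tokens for `Gemm`; the descriptor lifecycle of the C API stays the same. For orientation, a minimal caller sketch against the renamed entry points shown in this diff follows. `CHECK`, `device_malloc`, and `device_free` are placeholders (not part of this commit), and the include is assumed to declare the C API.

```c
#include <stddef.h>
#include "infiniop/ops/gemm.h"

/* Sketch only: CHECK, device_malloc, and device_free stand in for the
 * project's own error handling and device allocation; the handle and the
 * a/b/c tensor descriptors are assumed to be created elsewhere. */
static infiniStatus_t run_gemm(infiniopHandle_t handle,
                               infiniopTensorDescriptor_t c_desc,
                               infiniopTensorDescriptor_t a_desc,
                               infiniopTensorDescriptor_t b_desc,
                               void *c, const void *a, const void *b,
                               float alpha, float beta, void *stream) {
    infiniopGemmDescriptor_t desc;
    CHECK(infiniopCreateGemmDescriptor(handle, &desc, c_desc, a_desc, b_desc));

    size_t workspace_size = 0;
    CHECK(infiniopGetGemmWorkspaceSize(desc, &workspace_size));
    void *workspace = device_malloc(workspace_size);

    /* c = alpha * (a @ b) + beta * c */
    CHECK(infiniopGemm(desc, workspace, workspace_size,
                       c, a, b, alpha, beta, stream));

    device_free(workspace);
    CHECK(infiniopDestroyGemmDescriptor(desc));
    return INFINI_STATUS_SUCCESS;
}
```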
#include "matmul_maca.h"
#include "gemm_maca.h"
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/maca/maca_handle.h"
namespace op::matmul::maca {
namespace op::gemm::maca {
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
@@ -106,4 +106,4 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_SUCCESS;
}
} // namespace op::matmul::maca
} // namespace op::gemm::maca
#ifndef __GEMM_MACA_H__
#define __GEMM_MACA_H__
#include "../gemm.h"
DESCRIPTOR(maca)
#endif // __GEMM_MACA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/matmul.h"
#include "infiniop/ops/gemm.h"
#ifdef ENABLE_CPU_API
#include "cpu/matmul_cpu.h"
#include "cpu/gemm_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/matmul_cuda.cuh"
#include "cuda/gemm_cuda.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/matmul_bang.h"
#include "bang/gemm_bang.h"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/matmul_ascend.h"
#include "ascend/gemm_ascend.h"
#endif
#ifdef ENABLE_METAX_API
#include "maca/matmul_maca.h"
#include "maca/gemm_maca.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/matmul_kunlun.h"
#include "kunlun/gemm_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateMatmulDescriptor(
__C infiniStatus_t infiniopCreateGemmDescriptor(
infiniopHandle_t handle,
infiniopMatmulDescriptor_t *desc_ptr,
infiniopGemmDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::matmul::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::matmul::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
a_desc, \
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::gemm::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::gemm::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
a_desc, \
b_desc)
switch (handle->device) {
@@ -66,13 +66,13 @@ __C infiniStatus_t infiniopCreateMatmulDescriptor(
}
__C infiniStatus_t
infiniopGetMatmulWorkspaceSize(
infiniopMatmulDescriptor_t desc,
infiniopGetGemmWorkspaceSize(
infiniopGemmDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc)->workspace_size; \
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace_size; \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
@@ -103,8 +103,8 @@ infiniopGetMatmulWorkspaceSize(
#undef GET
}
__C infiniStatus_t infiniopMatmul(
infiniopMatmulDescriptor_t desc,
__C infiniStatus_t infiniopGemm(
infiniopGemmDescriptor_t desc,
void *workspace, size_t workspace_size,
void *c,
const void *a,
@@ -113,12 +113,12 @@ __C infiniStatus_t infiniopMatmul(
float beta,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, \
c, beta, \
a, b, alpha, \
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, \
c, beta, \
a, b, alpha, \
stream)
switch (desc->device_type) {
@@ -150,11 +150,11 @@ __C infiniStatus_t infiniopMatmul(
}
__C infiniStatus_t
infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) {
infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::matmul::NAMESPACE::Descriptor *>(desc); \
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
......
#ifndef __MATMUL_ASCEND_H__
#define __MATMUL_ASCEND_H__
#include "../matmul.h"
DESCRIPTOR(ascend)
#endif // __MATMUL_ASCEND_H__
#ifndef __MATMUL_BANG_H__
#define __MATMUL_BANG_H__
#include "../matmul.h"
DESCRIPTOR(bang)
#endif // __MATMUL_BANG_H__
#ifndef __MATMUL_CPU_H__
#define __MATMUL_CPU_H__
#include "../matmul.h"
DESCRIPTOR(cpu)
#endif // __MATMUL_CPU_H__
#ifndef __MATMUL_CUDA_CUH__
#define __MATMUL_CUDA_CUH__
#include "../matmul.h"
DESCRIPTOR(cuda)
#endif // __MATMUL_CUDA_CUH__
#ifndef __MATMUL_KUNLUN_H__
#define __MATMUL_KUNLUN_H__
#include "../matmul.h"
DESCRIPTOR(kunlun)
#endif // __MATMUL_KUNLUN_H__
#ifndef __MATMUL_MACA_H__
#define __MATMUL_MACA_H__
#include "../matmul.h"
DESCRIPTOR(maca)
#endif // __MATMUL_MACA_H__
@@ -14,11 +14,11 @@ xmake build infiniop-test
- Generate test cases
In the `/test/infiniop-test/` directory, run the matrix-multiplication test-case generation script; when it finishes, a `matmul.gguf` test-case file is produced in `/test/infiniop-test/`.
In the `/test/infiniop-test/` directory, run the matrix-multiplication test-case generation script; when it finishes, a `gemm.gguf` test-case file is produced in `/test/infiniop-test/`.
```bash
cd /test/infiniop-test/
python -m test_generate.testcases.matmul
python -m test_generate.testcases.gemm
```
- Run the test cases
@@ -29,10 +29,10 @@ python -m test_generate.testcases.matmul
infiniop-test --help
```
Example: test the `matmul.gguf` test-case file on CPU with 20 warmup runs and 1000 timed runs.
Example: test the `gemm.gguf` test-case file on CPU with 20 warmup runs and 1000 timed runs.
```bash
infiniop-test matmul.gguf --cpu --warmup 20 --run 1000
infiniop-test gemm.gguf --cpu --warmup 20 --run 1000
```
## Custom test cases
......
@@ -6,7 +6,7 @@ from typing import List
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides
def matmul(
def gemm(
a: np.ndarray,
b: np.ndarray,
alpha: float = 1.0,
@@ -19,12 +19,12 @@ def matmul(
def random_tensor(shape, dtype):
rate = 1e-3 # so far, the tests do not all pass with rate=1e-2
var = 0.5 * rate # keeps the sampled values within [-5e-4, 5e-4]
rate = 1e-3
var = 0.5 * rate # values fall within [-5e-4, 5e-4]
return rate * np.random.rand(*shape).astype(dtype) - var
class MatmulTestCase(InfiniopTestCase):
class GemmTestCase(InfiniopTestCase):
def __init__(
self,
a: np.ndarray,
@@ -36,7 +36,7 @@ class MatmulTestCase(InfiniopTestCase):
alpha: float,
beta: float,
):
super().__init__("matmul")
super().__init__("gemm")
self.a = a
self.stride_a = stride_a
self.b = b
@@ -65,7 +65,7 @@ class MatmulTestCase(InfiniopTestCase):
test_writer.add_tensor(
test_writer.gguf_key("c"), self.c, raw_dtype=np_dtype_to_ggml(self.c.dtype)
)
ans = matmul(
ans = gemm(
self.a.astype(np.float64),
self.b.astype(np.float64),
self.alpha,
@@ -78,10 +78,10 @@ if __name__ == "__main__":
if __name__ == "__main__":
test_writer = InfiniopTestWriter("matmul.gguf")
test_writer = InfiniopTestWriter("gemm.gguf")
# a, stride_a, b, stride_b, c, stride_c, alpha, beta
test_cases = [
MatmulTestCase(
GemmTestCase(
random_tensor((4, 5), np.float32),
None,
random_tensor((5, 6), np.float32),
@@ -91,7 +91,7 @@ if __name__ == "__main__":
1.0,
0.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((4, 5), np.float32),
gguf_strides(1, 4),
random_tensor((5, 6), np.float32),
@@ -101,7 +101,7 @@ if __name__ == "__main__":
1.0,
1.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((4, 5), np.float16),
None,
random_tensor((5, 6), np.float16),
@@ -111,7 +111,7 @@ if __name__ == "__main__":
1.0,
0.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((4, 5), np.float16),
gguf_strides(1, 4),
random_tensor((5, 6), np.float16),
@@ -121,7 +121,7 @@ if __name__ == "__main__":
1.0,
1.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((1, 2048), np.float16),
gguf_strides(1, 2048),
random_tensor((2048, 2048), np.float16),
@@ -131,7 +131,7 @@ if __name__ == "__main__":
1.0,
0.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((1, 2048), np.float32),
None,
random_tensor((2048, 2048), np.float32),
@@ -141,7 +141,7 @@ if __name__ == "__main__":
1.0,
0.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((2, 4, 2048), np.float16),
None,
random_tensor((2, 2048, 2048), np.float16),
@@ -151,7 +151,7 @@ if __name__ == "__main__":
1.0,
0.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((2, 4, 2048), np.float32),
None,
random_tensor((2, 2048, 2048), np.float32),
@@ -161,7 +161,7 @@ if __name__ == "__main__":
1.0,
0.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((6, 2048), np.float32),
gguf_strides(1, 2048),
random_tensor((2048, 2560), np.float32),
@@ -171,7 +171,7 @@ if __name__ == "__main__":
1.0,
1.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((4, 48, 64), np.float16),
None,
random_tensor((4, 64, 6), np.float16),
@@ -181,7 +181,7 @@ if __name__ == "__main__":
1.0 / 8,
1.0,
),
MatmulTestCase(
GemmTestCase(
random_tensor((4, 48, 64), np.float32),
None,
random_tensor((4, 64, 6), np.float32),
......
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float, c_bool
import torch
import ctypes
import sys
import os
import time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
open_lib,
to_tensor,
get_test_devices,
check_error,
rearrange_tensor,
rearrange_if_needed,
create_workspace,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
)
from operatorspy.tests.test_utils import get_args
import torch
# constant that controls whether to profile the PyTorch and lib functions
# NOTE: a synchronization call must be added to the lib function manually,
# e.g., cudaDeviceSynchronize() for CUDA
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2},
torch.float32: {"atol": 0, "rtol": 1e-3},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class GEMMDescriptor(Structure):
# ==============================================================================
# Definitions
# ==============================================================================
class GemmDescriptor(Structure):
_fields_ = [("device", c_int32)]
infiniopGEMMDescriptor_t = POINTER(GEMMDescriptor)
infiniopGemmDescriptor_t = POINTER(GemmDescriptor)
def gemm(
A, B, C=None, transA=False, transB=False, alpha=1.0, beta=0.0, dtype=torch.float32
):
A = A.T if transA else A
B = B.T if transB else B
result = alpha * torch.matmul(
A if dtype != torch.float16 else A.to(torch.float32),
B if dtype != torch.float16 else B.to(torch.float32),
).to(dtype)
if C is not None:
result += beta * C if dtype != torch.float16 else C.to(torch.float32)
if PROFILE:
torch.cuda.synchronize()
return result
# PyTorch implementation for matrix multiplication
def gemm(_c, beta, _a, _b, alpha):
a, b, c = _a.clone(), _b.clone(), _c.clone()
result_dtype = c.dtype
fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
return alpha * fp32_result.to(result_dtype) + beta * c
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
lib,
handle,
torch_device,
alpha,
beta,
transA,
transB,
a_shape,
b_shape,
c_shape,
y_shape,
a_stride=None,
b_stride=None,
c_stride=None,
y_stride=None,
dtype=torch.float16,
):
print(
f"Testing GEMM on {torch_device} with transA: {transA} transB: {transB} "
f"a_shape:{a_shape} b_shape:{b_shape} c_shape:{c_shape} y_shape:{y_shape} "
f"a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} y_stride:{y_stride} dtype:{dtype}"
f"Testing Gemm on {torch_device} with alpha:{alpha}, beta:{beta},"
f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape},"
f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{dtype}"
)
# Initialize tensors
a = torch.rand(a_shape, dtype=dtype).to(torch_device)
b = torch.rand(b_shape, dtype=dtype).to(torch_device)
c = torch.rand(c_shape, dtype=dtype).to(torch_device) if c_shape else None
y = torch.rand(y_shape, dtype=dtype).to(torch_device)
c = torch.ones(c_shape, dtype=dtype).to(torch_device)
if a_stride is not None:
a = rearrange_tensor(a, a_stride)
if b_stride is not None:
b = rearrange_tensor(b, b_stride)
if c_stride is not None and c is not None:
c = rearrange_tensor(c, c_stride)
if y_stride is not None:
y = rearrange_tensor(y, y_stride)
# Compute the PyTorch reference result
ans = gemm(c, beta, a, b, alpha)
for i in range(NUM_PRERUN if PROFILE else 1):
ans = gemm(a, b, c, transA, transB, alpha, beta, dtype)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = gemm(a, b, c, transA, transB, alpha, beta, dtype)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f"pytorch time: {elapsed :6f}")
a, b, c = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
a_tensor = to_tensor(a, lib)
b_tensor = to_tensor(b, lib)
c_tensor = to_tensor(c, lib) if c is not None else None
y_tensor = to_tensor(y, lib)
descriptor = infiniopGEMMDescriptor_t()
descriptor = infiniopGemmDescriptor_t()
check_error(
lib.infiniopCreateGEMMDescriptor(
lib.infiniopCreateGemmDescriptor(
handle,
ctypes.byref(descriptor),
y_tensor.descriptor,
c_tensor.descriptor,
a_tensor.descriptor,
b_tensor.descriptor,
c_tensor.descriptor if c_tensor else None,
alpha,
beta,
transA,
transB,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
a_tensor.descriptor.contents.invalidate()
b_tensor.descriptor.contents.invalidate()
if c_tensor is not None:
c_tensor.descriptor.contents.invalidate()
y_tensor.descriptor.contents.invalidate()
for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.destroyDesc(lib)
workspace_size = ctypes.c_uint64(0)
# Get workspace size and create workspace
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetGEMMWorkspaceSize(descriptor, ctypes.byref(workspace_size))
lib.infiniopGetGemmWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = torch.zeros(int(workspace_size.value), dtype=torch.uint8).to(
torch_device
)
workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
workspace = create_workspace(workspace_size.value, a.device)
for i in range(NUM_PRERUN if PROFILE else 1):
# Execute infiniop gemm operator
def lib_gemm():
check_error(
lib.infiniopGEMM(
lib.infiniopGemm(
descriptor,
workspace_ptr,
workspace_size,
y_tensor.data,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data,
a_tensor.data,
b_tensor.data,
c_tensor.data if c_tensor else None,
alpha,
beta,
None,
)
)
if PROFILE:
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopGEMM(
descriptor,
workspace_ptr,
workspace_size,
y_tensor.data,
a_tensor.data,
b_tensor.data,
c_tensor.data if c_tensor else None,
None,
)
)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed :6f}")
assert torch.allclose(y, ans, atol=0, rtol=1e-2)
check_error(lib.infiniopDestroyGEMMDescriptor(descriptor))
lib_gemm()
def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for (
alpha,
beta,
transA,
transB,
a_shape,
b_shape,
c_shape,
y_shape,
a_stride,
b_stride,
c_stride,
y_stride,
) in test_cases:
# fmt: off
test(lib, handle, "cpu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16)
test(lib, handle, "cpu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32)
# fmt: on
destroy_handle(lib, handle)
# Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(c, ans, atol=atol, rtol=rtol)
assert torch.allclose(c, ans, atol=atol, rtol=rtol)
def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for (
alpha,
beta,
transA,
transB,
a_shape,
b_shape,
c_shape,
y_shape,
a_stride,
b_stride,
c_stride,
y_stride,
) in test_cases:
# fmt: off
test(lib, handle, "cuda", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16)
test(lib, handle, "cuda", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32)
# fmt: on
destroy_handle(lib, handle)
def test_bang(lib, test_cases):
import torch_mlu
device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for (
alpha,
beta,
transA,
transB,
a_shape,
b_shape,
c_shape,
y_shape,
a_stride,
b_stride,
c_stride,
y_stride,
) in test_cases:
# Profiling workflow
if PROFILE:
# fmt: off
test(lib, handle, "mlu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float16)
test(lib, handle, "mlu", alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride, dtype=torch.float32)
profile_operation("PyTorch", lambda: gemm(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_gemm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
destroy_handle(lib, handle)
check_error(lib.infiniopDestroyGemmDescriptor(descriptor))
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
test_cases = [
# alpha, beta, transA, transB, a_shape, b_shape, c_shape, y_shape, a_stride, b_stride, c_stride, y_stride
(
1.0,
1.0,
False,
False,
(1, 2048),
(2048, 2048),
(1, 2048),
(1, 2048),
None,
None,
None,
None,
),
(
1.0,
1.0,
True,
True,
(2048, 4),
(2048, 2048),
(4, 2048),
(4, 2048),
None,
None,
None,
None,
),
(
1.0,
1.0,
False,
True,
(1, 2048),
(1000, 2048),
(1000),
(1, 1000),
None,
None,
None,
None,
),
(
1.0,
1.0,
True,
False,
(2048, 4),
(2048, 2048),
(2048),
(4, 2048),
(4096, 1),
(4096, 1),
(2,),
(4096, 1),
),
(
1.0,
1.0,
False,
False,
(3, 1, 2048),
(3, 2048, 2048),
(1,),
(3, 1, 2048),
None,
None,
None,
None,
),
(
1.0,
1.0,
True,
False,
(2048, 4),
(2048, 2048),
None,
(4, 2048),
(4096, 1),
(4096, 1),
(2,),
(4096, 1),
),
]
args = get_args()
lib = open_lib()
lib.infiniopCreateGEMMDescriptor.restype = c_int32
lib.infiniopCreateGEMMDescriptor.argtypes = [
lib.infiniopCreateGemmDescriptor.restype = c_int32
lib.infiniopCreateGemmDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopGEMMDescriptor_t),
POINTER(infiniopGemmDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_float,
c_float,
c_bool,
c_bool,
]
lib.infiniopGetGEMMWorkspaceSize.restype = c_int32
lib.infiniopGetGEMMWorkspaceSize.argtypes = [
infiniopGEMMDescriptor_t,
POINTER(c_uint64),
lib.infiniopGetGemmWorkspaceSize.restype = c_int32
lib.infiniopGetGemmWorkspaceSize.argtypes = [
infiniopGemmDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopGEMM.restype = c_int32
lib.infiniopGEMM.argtypes = [
infiniopGEMMDescriptor_t,
lib.infiniopGemm.restype = c_int32
lib.infiniopGemm.argtypes = [
infiniopGemmDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_float,
c_float,
c_void_p,
]
lib.infiniopDestroyGEMMDescriptor.restype = c_int32
lib.infiniopDestroyGEMMDescriptor.argtypes = [
infiniopGEMMDescriptor_t,
lib.infiniopDestroyGemmDescriptor.restype = c_int32
lib.infiniopDestroyGemmDescriptor.argtypes = [
infiniopGemmDescriptor_t,
]
if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if args.bang:
test_bang(lib, test_cases)
if not (args.cpu or args.cuda or args.bang):
test_cpu(lib, test_cases)
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
# Execute tests
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
infiniopHandle_t,
infiniopTensorDescriptor_t,
open_lib,
to_tensor,
get_test_devices,
check_error,
rearrange_if_needed,
create_workspace,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2},
torch.float32: {"atol": 0, "rtol": 1e-3},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
# ==============================================================================
# Definitions
# ==============================================================================
class MatmulDescriptor(Structure):
_fields_ = [("device", c_int32)]
infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor)
# PyTorch implementation for matrix multiplication
def matmul(_c, beta, _a, _b, alpha):
a, b, c = _a.clone(), _b.clone(), _c.clone()
result_dtype = c.dtype
fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
return alpha * fp32_result.to(result_dtype) + beta * c
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
lib,
handle,
torch_device,
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride=None,
b_stride=None,
c_stride=None,
dtype=torch.float16,
):
print(
f"Testing Matmul on {torch_device} with alpha:{alpha}, beta:{beta},"
f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape},"
f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{dtype}"
)
# Initialize tensors
a = torch.rand(a_shape, dtype=dtype).to(torch_device)
b = torch.rand(b_shape, dtype=dtype).to(torch_device)
c = torch.ones(c_shape, dtype=dtype).to(torch_device)
# Compute the PyTorch reference result
ans = matmul(c, beta, a, b, alpha)
a, b, c = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
descriptor = infiniopMatmulDescriptor_t()
check_error(
lib.infiniopCreateMatmulDescriptor(
handle,
ctypes.byref(descriptor),
c_tensor.descriptor,
a_tensor.descriptor,
b_tensor.descriptor,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.destroyDesc(lib)
# Get workspace size and create workspace
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, a.device)
# Execute infiniop matmul operator
def lib_matmul():
check_error(
lib.infiniopMatmul(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data,
a_tensor.data,
b_tensor.data,
alpha,
beta,
None,
)
)
lib_matmul()
# Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(c, ans, atol=atol, rtol=rtol)
assert torch.allclose(c, ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroyMatmulDescriptor(descriptor))
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateMatmulDescriptor.restype = c_int32
lib.infiniopCreateMatmulDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopMatmulDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetMatmulWorkspaceSize.restype = c_int32
lib.infiniopGetMatmulWorkspaceSize.argtypes = [
infiniopMatmulDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopMatmul.restype = c_int32
lib.infiniopMatmul.argtypes = [
infiniopMatmulDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
c_float,
c_float,
c_void_p,
]
lib.infiniopDestroyMatmulDescriptor.restype = c_int32
lib.infiniopDestroyMatmulDescriptor.argtypes = [
infiniopMatmulDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
# Execute tests
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")