"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "464d98962d28d812531f4442c98635607f399be9"
Commit 2ccf1d9d authored by Pepe, committed by wooway777

issue/205 - Add Sub operator

issue/205 - Add the Sub operator's header, CPU implementation, CUDA implementation, and Python tests
parent 8a22f194
@@ -19,6 +19,7 @@
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rope.h"
#include "infiniop/ops/sub.h"
#include "infiniop/ops/swiglu.h"
#include "infiniop/tensor_descriptor.h"
......
#ifndef __INFINIOP_SUB_API_H__
#define __INFINIOP_SUB_API_H__
#include "../operator_descriptor.h"
typedef struct InfiniopDescriptor *infiniopSubDescriptor_t;
__C __export infiniStatus_t infiniopCreateSubDescriptor(infiniopHandle_t handle,
infiniopSubDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);
__C __export infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
void const *a,
void const *b,
void *stream);
__C __export infiniStatus_t infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc);
#endif
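For orientation, here is a minimal host-side sketch of the lifecycle this header defines: build tensor descriptors, create the Sub descriptor, query and allocate the workspace, launch, destroy. Note that `infiniopCreateTensorDescriptor` and its argument order are assumptions based on `infiniop/tensor_descriptor.h`, not part of this commit; check the actual header before relying on it.

// Hedged usage sketch (not part of the commit). The Sub entry points are
// those declared above; infiniopCreateTensorDescriptor's signature is an
// assumption and may differ in your tree.
#include <cstdlib>
#include <vector>
#include "infiniop/ops/sub.h"
#include "infiniop/tensor_descriptor.h"

#define CHECK(call)                                        \
    do {                                                   \
        if ((call) != INFINI_STATUS_SUCCESS) std::abort(); \
    } while (0)

// Computes c = a - b for two contiguous 13x4 f32 buffers.
void run_sub(infiniopHandle_t handle, void *c, const void *a, const void *b) {
    size_t shape[2] = {13, 4};
    infiniopTensorDescriptor_t ta, tb, tc;
    // nullptr strides = contiguous layout (assumed semantics)
    CHECK(infiniopCreateTensorDescriptor(&ta, 2, shape, nullptr, INFINI_DTYPE_F32));
    CHECK(infiniopCreateTensorDescriptor(&tb, 2, shape, nullptr, INFINI_DTYPE_F32));
    CHECK(infiniopCreateTensorDescriptor(&tc, 2, shape, nullptr, INFINI_DTYPE_F32));

    infiniopSubDescriptor_t desc;
    CHECK(infiniopCreateSubDescriptor(handle, &desc, tc, ta, tb));

    size_t ws_size = 0;
    CHECK(infiniopGetSubWorkspaceSize(desc, &ws_size));
    std::vector<char> workspace(ws_size); // must be device memory on GPU backends

    CHECK(infiniopSub(desc, ws_size ? workspace.data() : nullptr, ws_size,
                      c, a, b, /*stream=*/nullptr));
    CHECK(infiniopDestroySubDescriptor(desc));
    // (tensor descriptors would be released via the tensor_descriptor.h API)
}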
@@ -17,6 +17,7 @@ def run_tests(args):
"random_sample.py",
"rms_norm.py",
"rope.py",
"sub.py",
"swiglu.py",
"attention.py",
]:
......
#include "sub_cpu.h"
namespace op::sub::cpu {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &a_desc = input_desc_vec.at(0);
const auto &b_desc = input_desc_vec.at(1);
const auto &c_shape = out_desc->shape();
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
// create CPU elementwise descriptor
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<SubOp, fp16_t>(_info, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<SubOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<SubOp, double>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::sub::cpu
#ifndef __SUB_CPU_H__
#define __SUB_CPU_H__
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR(sub, cpu)
namespace op::sub::cpu {
struct SubOp {
    static constexpr size_t num_inputs = 2;

    template <typename T>
    T operator()(const T &a, const T &b) const {
        return a - b;
    }
};
} // namespace op::sub::cpu
#endif // __SUB_CPU_H__
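The CPU path above delegates all iteration to the shared elementwise framework, so `SubOp` only supplies a scalar body. Below is a self-contained toy version of that pattern, for contiguous buffers only; the real `elementwise_cpu.h` driver also walks arbitrary strides and broadcast shapes.

#include <cstddef>
#include <cstdio>

// Stand-in for op::sub::cpu::SubOp: a stateless functor with a scalar body.
struct SubOp {
    static constexpr size_t num_inputs = 2;
    template <typename T>
    T operator()(const T &a, const T &b) const { return a - b; }
};

// Toy contiguous driver; the framework's real driver handles strides
// and broadcasting on top of the same functor interface.
template <typename Op, typename T>
void elementwise(T *out, const T *a, const T *b, size_t n) {
    Op op;
    for (size_t i = 0; i < n; ++i) {
        out[i] = op(a[i], b[i]);
    }
}

int main() {
    float a[4] = {1.f, 2.f, 3.f, 4.f};
    float b[4] = {0.5f, 0.5f, 0.5f, 0.5f};
    float c[4];
    elementwise<SubOp>(c, a, b, 4);
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 0.5 1.5 2.5 3.5
}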
#include "sub_cuda.cuh"
#include "sub_cuda_internal.cuh"
namespace op::sub::cuda {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &a_desc = input_desc_vec.at(0);
const auto &b_desc = input_desc_vec.at(1);
const auto &c_shape = out_desc->shape();
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
// create CUDA elementwise descriptor
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, SubOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, SubOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, SubOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::sub::cuda
#ifndef __SUB_CUDA_API_H__
#define __SUB_CUDA_API_H__
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
ELEMENTWISE_DESCRIPTOR(sub, cuda)
#endif // __SUB_CUDA_API_H__
#ifndef __SUB_CUDA_H__
#define __SUB_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>
#include <type_traits>
namespace op::sub::cuda {
struct SubOp {
    static constexpr size_t num_inputs = 2;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        if constexpr (std::is_same_v<T, half2>) {
            return __hsub2(a, b); // packed fp16: two subtractions at once
        } else if constexpr (std::is_same_v<T, half>) {
            return __hsub(a, b);
        } else if constexpr (std::is_same_v<T, float>) {
            // round-to-nearest matches the PyTorch reference used by the
            // tests; __fsub_rd (round toward -inf) can drift by 1 ulp
            return __fsub_rn(a, b);
        } else {
            return a - b;
        }
    }
};
} // namespace op::sub::cuda
#endif // __SUB_CUDA_H__
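The `half2` branch exists so the elementwise framework can process two fp16 lanes per thread with a single `__hsub2`. A minimal kernel sketch of that packing, assuming an even element count (the real driver also handles the odd tail and strided layouts):

#include <cuda_fp16.h>

// Stand-in functor mirroring op::sub::cuda::SubOp's half2 fast path.
struct PackedSubOp {
    __device__ __forceinline__ half2 operator()(half2 a, half2 b) const {
        return __hsub2(a, b); // two fp16 subtractions in one instruction
    }
};

// Toy kernel: reinterpret the fp16 buffers as half2 so each thread
// handles one packed pair. n must be even here.
__global__ void sub_f16_packed(half *c, const half *a, const half *b, size_t n) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n / 2) {
        const half2 *a2 = reinterpret_cast<const half2 *>(a);
        const half2 *b2 = reinterpret_cast<const half2 *>(b);
        half2 *c2 = reinterpret_cast<half2 *>(c);
        c2[i] = PackedSubOp{}(a2[i], b2[i]);
    }
}

Launched as, e.g., `sub_f16_packed<<<(n / 2 + 255) / 256, 256>>>(c, a, b, n);`, mirroring the 256-thread block size hard-coded in `calculate` above.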
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/sub.h"
#ifdef ENABLE_CPU_API
#include "cpu/sub_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/sub_cuda.cuh"
#endif
__C infiniStatus_t infiniopCreateSubDescriptor(
infiniopHandle_t handle,
infiniopSubDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::sub::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::sub::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
{a_desc, \
b_desc})
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::sub::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
}
__C infiniStatus_t infiniopSub(
infiniopSubDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::sub::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, c, {a, b}, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t
infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::sub::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
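Each entry point above follows the same define/switch/#undef idiom: a function-like macro expands one `case` per backend compiled in, and is undefined right after the switch so the next entry point can reuse the name. Stripped to its skeleton (all names here are illustrative, not from the codebase):

#include <cstdio>

enum Device { DEV_CPU, DEV_NVIDIA };

namespace cpu    { int create() { return 0; } }
namespace nvidia { int create() { return 0; } }

int create_for(Device device) {
// Expands to "case CASE: return NAMESPACE::create()" per backend.
#define CREATE(CASE, NAMESPACE) \
    case CASE:                  \
        return NAMESPACE::create()

    switch (device) {
        CREATE(DEV_CPU, cpu);
        CREATE(DEV_NVIDIA, nvidia);
    default:
        return -1; // device type not supported
    }
#undef CREATE // free the name for the next entry point
}

int main() {
    std::printf("%d\n", create_for(DEV_CPU)); // 0
}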
import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
from libinfiniop import (
infiniopHandle_t,
infiniopTensorDescriptor_t,
open_lib,
to_tensor,
get_test_devices,
check_error,
rearrange_if_needed,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
create_workspace,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# shape, a_stride, b_stride, c_stride
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4), (0, 1), None, None),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
]
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_A = auto()
INPLACE_B = auto()
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [
Inplace.OUT_OF_PLACE,
Inplace.INPLACE_A,
Inplace.INPLACE_B,
]
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
torch.float32: {"atol": 1e-7, "rtol": 1e-7},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class SubDescriptor(Structure):
_fields_ = [("device", c_int32)]
infiniopSubDescriptor_t = POINTER(SubDescriptor)
def sub(x, y):
return torch.sub(x, y)
def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace):
"""
rearrange the tensors if needed and apply the inplace config.
if inplace is true and the output (i.e., c) is placed to the broadcasted input,
the inplace config is ignored and out-of-place is used
"""
original_c_strides = c_strides if c_strides else c.stride()
    def _rearrange(tensor, strides):
        # Zero strides emulate broadcasting; copying via rearrange would
        # materialize the broadcast, so build the strided view in place.
        if strides and 0 in strides:
            tensor.set_(tensor.untyped_storage(), 0, tensor.shape, strides)
            return tensor
        return rearrange_if_needed(tensor, strides)
a, b, c = [
_rearrange(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_strides])
]
c = (
c
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
    # if c ended up with broadcasted (zero) strides, reset it to the original
    # unbroadcasted strides so every output element is independently writable
if 0 in c.stride():
c.set_(c.untyped_storage(), 0, c.shape, original_c_strides)
return a, b, c
def test(
lib,
handle,
torch_device,
shape,
a_stride=None,
b_stride=None,
c_stride=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float16,
sync=None,
):
print(
f"Testing Sub on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} "
f"dtype:{dtype} inplace:{inplace}"
)
a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device)
a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace)
ans = sub(a, b)
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = (
to_tensor(c, lib)
if inplace == Inplace.OUT_OF_PLACE
else (a_tensor if inplace == Inplace.INPLACE_A else b_tensor)
)
if sync is not None:
sync()
descriptor = infiniopSubDescriptor_t()
check_error(
lib.infiniopCreateSubDescriptor(
handle,
ctypes.byref(descriptor),
c_tensor.descriptor,
a_tensor.descriptor,
b_tensor.descriptor,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.destroyDesc(lib)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetSubWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, c.device)
def lib_sub():
check_error(
lib.infiniopSub(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data,
a_tensor.data,
b_tensor.data,
None,
)
)
lib_sub()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(c, ans, atol=atol, rtol=rtol)
assert torch.allclose(c, ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: sub(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_sub(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroySubDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateSubDescriptor.restype = c_int32
lib.infiniopCreateSubDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopSubDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetSubWorkspaceSize.restype = c_int32
lib.infiniopGetSubWorkspaceSize.argtypes = [
infiniopSubDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopSub.restype = c_int32
lib.infiniopSub.argtypes = [
infiniopSubDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroySubDescriptor.restype = c_int32
lib.infiniopDestroySubDescriptor.argtypes = [
infiniopSubDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")