Commit 8a49900f authored by goldenfox2025's avatar goldenfox2025
Browse files

issue/180:添加clip算子

parent 1a4cfb99
......@@ -6,6 +6,7 @@
#include "infiniop/ops/attention.h"
#include "infiniop/ops/avg_pool.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/expand.h"
#include "infiniop/ops/gemm.h"
......
#ifndef __INFINIOP_CLIP_API_H__
#define __INFINIOP_CLIP_API_H__
#include "../operator_descriptor.h"
typedef struct InfiniopDescriptor *infiniopClipDescriptor_t;
__C __export infiniStatus_t infiniopCreateClipDescriptor(infiniopHandle_t handle,
infiniopClipDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t min_val,
infiniopTensorDescriptor_t max_val);
__C __export infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopClip(infiniopClipDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *min_val,
const void *max_val,
void *stream);
__C __export infiniStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc);
#endif
#ifndef __CLIP_H__
#define __CLIP_H__
#include "../../elementwise/elementwise.h"
#include "../../operator.h"
/**
* @brief Define the Clip descriptor for the ternary operator
*
* This macro defines a Descriptor class for the Clip operator that inherits from InfiniopDescriptor.
* It uses the standard elementwise operation fields and methods for a ternary operator
* where min_val and max_val are tensors.
*
* @param OP The operator name (clip)
* @param NAMESPACE The namespace (cpu or cuda)
*/
#define CLIP_DESCRIPTOR(OP, NAMESPACE) \
\
namespace op::OP::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
size_t _workspace_size; \
\
public: \
Descriptor( \
infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \
_info(std::move(info)), \
_device_info(std::move(device_info)), \
_workspace_size(workspace_size) {} \
\
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \
std::vector<const void *> inputs, \
void *stream) const; \
}; \
}
#endif // __CLIP_H__
#include "clip_cpu.h"
namespace op::clip::cpu {
Descriptor::~Descriptor() = default;
infiniStatus_t createClipDescriptor(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &in_desc = input_desc_vec.at(0);
const auto &min_desc = input_desc_vec.at(1);
const auto &max_desc = input_desc_vec.at(2);
const auto &out_shape = out_desc->shape();
const auto &in_shape = in_desc->shape();
const auto &min_shape = min_desc->shape();
const auto &max_shape = max_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(out_shape, in_shape);
CHECK_SAME_SHAPE(out_shape, min_shape);
CHECK_SAME_SHAPE(out_shape, max_shape);
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<ClipOp, fp16_t>(_info, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<ClipOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<ClipOp, double>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::clip::cpu
#ifndef __CLIP_CPU_H__
#define __CLIP_CPU_H__
#include "../../../elementwise/cpu/elementwise_cpu.h"
#include "../clip.h"
CLIP_DESCRIPTOR(clip, cpu)
namespace op::clip::cpu {
typedef struct ClipOp {
public:
static constexpr size_t num_inputs = 3;
template <typename T>
T operator()(const T &x, const T &min_val, const T &max_val) const {
return std::max(std::min(x, max_val), min_val);
}
} ClipOp;
// Create clip descriptor
infiniStatus_t createClipDescriptor(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec);
} // namespace op::clip::cpu
#endif // __CLIP_CPU_H__
#include "clip_cuda.cuh"
#include "clip_cuda_internal.cuh"
namespace op::clip::cuda {
Descriptor::~Descriptor() = default;
infiniStatus_t createClipDescriptor(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &in_desc = input_desc_vec.at(0);
const auto &min_desc = input_desc_vec.at(1);
const auto &max_desc = input_desc_vec.at(2);
const auto &out_shape = out_desc->shape();
const auto &in_shape = in_desc->shape();
const auto &min_shape = min_desc->shape();
const auto &max_shape = max_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(out_shape, in_shape);
CHECK_SAME_SHAPE(out_shape, min_shape);
CHECK_SAME_SHAPE(out_shape, max_shape);
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, ClipOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, ClipOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, ClipOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::clip::cuda
#ifndef __CLIP_CUDA_API_H__
#define __CLIP_CUDA_API_H__
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
#include "../clip.h"
CLIP_DESCRIPTOR(clip, cuda)
namespace op::clip::cuda {
infiniStatus_t createClipDescriptor(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec);
} // namespace op::clip::cuda
#endif // __CLIP_CUDA_API_H__
#ifndef __CLIP_CUDA_H__
#define __CLIP_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>
namespace op::clip::cuda {
typedef struct ClipOp {
public:
static constexpr size_t num_inputs = 3;
template <typename T>
__device__ __forceinline__ T operator()(const T &x, const T &min_val, const T &max_val) const {
if constexpr (std::is_same_v<T, half2>) {
return __hmax2(__hmin2(x, max_val), min_val);
} else if constexpr (std::is_same_v<T, half>) {
return __hmax(__hmin(x, max_val), min_val);
} else if constexpr (std::is_same_v<T, float>) {
return fmaxf(fminf(x, max_val), min_val);
} else if constexpr (std::is_same_v<T, double>) {
return fmax(fmin(x, max_val), min_val);
} else {
return std::max(std::min(x, max_val), min_val);
}
}
} ClipOp;
} // namespace op::clip::cuda
#endif // __CLIP_CUDA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/clip.h"
#ifdef ENABLE_CPU_API
#include "cpu/clip_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/clip_cuda.cuh"
#endif
__C infiniStatus_t infiniopCreateClipDescriptor(
infiniopHandle_t handle,
infiniopClipDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t min_val,
infiniopTensorDescriptor_t max_val) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::clip::NAMESPACE::createClipDescriptor( \
handle, \
reinterpret_cast<op::clip::NAMESPACE::Descriptor **>(desc_ptr), \
y, \
{x, min_val, max_val})
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::clip::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopClip(
infiniopClipDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *min_val,
const void *max_val,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::clip::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, {x, min_val, max_val}, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t
infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::clip::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
#!/usr/bin/env python3
import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
infiniopHandle_t,
infiniopTensorDescriptor_t,
open_lib,
to_tensor,
get_test_devices,
check_error,
rearrange_if_needed,
create_workspace,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# shape, x_stride, y_stride, min_val, max_val
# 基本形状测试
((10,), None, None, -1.0, 1.0),
((5, 10), None, None, -1.0, 1.0),
((2, 3, 4), None, None, -1.0, 1.0),
# 不同的min_val和max_val
((10,), None, None, 0.0, 2.0),
((5, 10), None, None, 0.0, 2.0),
((2, 3, 4), None, None, 0.0, 2.0),
((10,), None, None, -2.0, 0.0),
((5, 10), None, None, -2.0, 0.0),
((2, 3, 4), None, None, -2.0, 0.0),
# 奇怪形状测试
((7, 13), None, None, -1.0, 1.0), # 质数维度
((3, 5, 7), None, None, -1.0, 1.0), # 三维质数
# 非标准形状测试
((1, 1), None, None, -1.0, 1.0), # 最小形状
((100, 100), None, None, -1.0, 1.0), # 大形状
((16, 16, 16), None, None, -1.0, 1.0), # 大三维
# 极端值测试
((10,), None, None, -1000.0, 1000.0), # 大范围
((10,), None, None, -0.001, 0.001), # 小范围
((10,), None, None, 0.0, 0.0), # min=max
# 特殊形状测试
((0,), None, None, -1.0, 1.0), # 空张量
((1, 0), None, None, -1.0, 1.0), # 空维度
# 带stride的测试用例
((5, 10), (10, 1), None, -1.0, 1.0), # 行优先
((5, 10), (1, 5), None, -1.0, 1.0), # 列优先
((5, 10), (10, 1), (10, 1), -1.0, 1.0), # 输入输出都有stride
((5, 10), (1, 5), (1, 5), -1.0, 1.0), # 输入输出都有stride
((5, 10), (10, 1), (1, 5), -1.0, 1.0), # 输入输出有不同的stride
]
# 开发机cpu不支持fp16 没有测试
_TENSOR_DTYPES = [torch.float32, torch.float64]
_TOLERANCE_MAP = {
torch.float32: {"atol": 1e-7, "rtol": 1e-6},
torch.float64: {"atol": 1e-10, "rtol": 1e-10},
}
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_X = auto()
_INPLACE = [
Inplace.INPLACE_X,
Inplace.OUT_OF_PLACE,
]
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
class ClipDescriptor(Structure):
_fields_ = [("device_type", c_int32), ("device_id", c_int32)]
infiniopClipDescriptor_t = POINTER(ClipDescriptor)
def clip(x, min_val, max_val):
return torch.clamp(x, min_val, max_val)
def create_tensor_with_stride(shape, stride, dtype, device):
"""Create a tensor with specific stride without using view() that might cause errors."""
x = torch.rand(shape, dtype=dtype, device=device) * 4.0 - 2.0 # Range: [-2, 2]
if stride is None:
return x
if len(shape) == 2 and len(stride) == 2:
if stride == (shape[1], 1):
return x.contiguous()
elif stride == (1, shape[0]):
return x.transpose(0, 1).contiguous().transpose(0, 1)
else:
y = torch.zeros(shape, dtype=dtype, device=device)
for i in range(shape[0]):
for j in range(shape[1]):
y[i, j] = x[i, j]
return y.contiguous()
return x
def test(
lib,
handle,
torch_device,
shape,
x_stride=None,
y_stride=None,
min_val=-1.0,
max_val=1.0,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float32,
):
print(
f"Testing Clip on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} "
f"min_val:{min_val} max_val:{max_val} dtype:{dtype} inplace:{inplace}"
)
x = create_tensor_with_stride(shape, x_stride, dtype, torch_device)
# Create tensor versions of min_val and max_val with the same shape as x
min_tensor = torch.full(shape, min_val, dtype=dtype, device=torch_device)
max_tensor = torch.full(shape, max_val, dtype=dtype, device=torch_device)
ans = clip(x, min_val, max_val)
# 确保张量是连续的,然后再重新排列
x = x.contiguous()
min_tensor = min_tensor.contiguous()
max_tensor = max_tensor.contiguous()
x = rearrange_if_needed(x, x_stride)
min_tensor = rearrange_if_needed(min_tensor, x_stride)
max_tensor = rearrange_if_needed(max_tensor, x_stride)
x_tensor = to_tensor(x, lib)
min_tensor_desc = to_tensor(min_tensor, lib)
max_tensor_desc = to_tensor(max_tensor, lib)
if inplace == Inplace.INPLACE_X:
y = x
y_tensor = x_tensor
else:
y = torch.zeros(shape, dtype=dtype).to(torch_device)
y = y.contiguous() # 确保张量是连续的
y = rearrange_if_needed(y, y_stride)
y_tensor = to_tensor(y, lib)
descriptor = infiniopClipDescriptor_t()
check_error(
lib.infiniopCreateClipDescriptor(
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor,
min_tensor_desc.descriptor, max_tensor_desc.descriptor
)
)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetClipWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = create_workspace(workspace_size.value, x.device)
def lib_clip():
check_error(
lib.infiniopClip(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
y_tensor.data,
x_tensor.data,
min_tensor_desc.data,
max_tensor_desc.data,
None,
)
)
lib_clip()
# Now we can destroy the tensor descriptors
x_tensor.destroyDesc(lib)
min_tensor_desc.destroyDesc(lib)
max_tensor_desc.destroyDesc(lib)
if inplace != Inplace.INPLACE_X:
y_tensor.destroyDesc(lib)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG or not torch.allclose(y, ans, atol=atol, rtol=rtol):
print("\nExpected:")
print(ans)
print("\nActual:")
print(y)
print("\nDifference:")
print(torch.abs(y - ans))
print("\nMax difference:", torch.max(torch.abs(y - ans)).item())
debug(y, ans, atol=atol, rtol=rtol)
assert torch.allclose(y, ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: clip(x, min_val, max_val), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_clip(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroyClipDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
lib = open_lib()
lib.infiniopCreateClipDescriptor.restype = c_int32
lib.infiniopCreateClipDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopClipDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetClipWorkspaceSize.restype = c_int32
lib.infiniopGetClipWorkspaceSize.argtypes = [
infiniopClipDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopClip.restype = c_int32
lib.infiniopClip.argtypes = [
infiniopClipDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyClipDescriptor.restype = c_int32
lib.infiniopDestroyClipDescriptor.argtypes = [
infiniopClipDescriptor_t,
]
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment