Unverified Commit 681f4e1e authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #745 from InfiniTensor/issue/744

Issue/744 kunlun softplus
parents b38d5d16 f817c394
#ifndef __SOFTPLUS_KUNLUN_KERNEL_H__
#define __SOFTPLUS_KUNLUN_KERNEL_H__
namespace op::softplus::kunlun {
/**
 * Element-wise softplus functor: y = log(1 + exp(x)).
 *
 * Inputs strictly greater than 20 are returned unchanged instead of going
 * through exp/log. half and bfloat16_t operands are widened to float for the
 * arithmetic and narrowed back on return; any other T is read into a float
 * and the result is converted back to T by the return statement.
 */
struct SoftplusOp {
    // The functor reads exactly one operand, inputs[0].
    static constexpr int num_inputs = 1;

    /**
     * @param inputs  Array of operands; only inputs[0] is read.
     * @return        softplus(inputs[0]) expressed in type T.
     */
    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        if constexpr (std::is_same_v<T, half>) {
            const float x = __half2float(inputs[0]);
            float y = x;
            // Written as a negated comparison so NaN takes the exp/log path.
            if (!(x > 20.0f)) {
                y = log(1 + exp(x));
            }
            return __float2half(y);
        } else if constexpr (std::is_same_v<T, bfloat16_t>) {
            const float x = __bfloat162float(inputs[0]);
            float y = x;
            if (!(x > 20.0f)) {
                y = log(1 + exp(x));
            }
            return __float2bfloat16(y);
        } else {
            const float x = inputs[0];
            if (x > 20.0f) {
                return x;
            }
            return log(1 + exp(x));
        }
    }
};
} // namespace op::softplus::kunlun
#endif // __SOFTPLUS_KUNLUN_KERNEL_H__
#ifndef __SOFTPLUS_KUNLUN_H__
#define __SOFTPLUS_KUNLUN_H__
#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"
// Expands the shared element-wise descriptor macro for the softplus operator on
// the kunlun backend. NOTE(review): presumably this declares
// op::softplus::kunlun::Descriptor, whose members are defined in the
// accompanying implementation file — confirm against elementwise_kunlun_api.h.
ELEMENTWISE_DESCRIPTOR(softplus, kunlun)
#endif
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "softplus_kunlun.h"
namespace op::softplus::kunlun {
// Out-of-line definition of the destructor using the compiler-generated default.
Descriptor::~Descriptor() = default;
// Validates the tensor descriptors for a softplus call and builds the Kunlun
// element-wise descriptor.
//
// handle_         Opaque handle; reinterpreted as device::kunlun::Handle *.
// desc_ptr        Not referenced directly in this body. NOTE(review): presumably
//                 written by CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR — confirm in the
//                 macro definition.
// out_desc        Output tensor descriptor; supplies the dtype and reference shape.
// input_desc_vec  Input tensor descriptors; only element 0 is inspected here.
//
// Returns INFINI_STATUS_SUCCESS when execution reaches the end of the body.
// NOTE(review): the CHECK_* and CREATE_* macros presumably return an error status
// early on failure — confirm in their definitions.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto dtype = out_desc->dtype();
// at(0) is bounds-checked: an empty input_desc_vec throws std::out_of_range here,
// before any of the checks below run.
const auto &x_desc = input_desc_vec.at(0);
const auto &y_shape = out_desc->shape();
const auto &x_shape = x_desc->shape();
// Only the output dtype is checked, against F16, F32 and BF16.
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
// The output shape is compared with the shape of input 0.
CHECK_SAME_SHAPE(y_shape, x_shape);
// create KUNLUN elementwise descriptor
CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
// Runs softplus through the element-wise device implementation, selecting the
// element type from the dtype stored in this descriptor.
//
// workspace / workspace_size  Caller-provided scratch buffer and its size.
// output                      Destination buffer, forwarded to the device call.
// inputs                      Input buffers, forwarded to the device call.
// stream                      Opaque stream handle, forwarded to the device call.
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when workspace_size is smaller
// than _workspace_size, INFINI_STATUS_BAD_TENSOR_DTYPE when _dtype is not
// F16/BF16/F32, and otherwise whatever the device call returns.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    const bool workspace_large_enough = workspace_size >= _workspace_size;
    if (!workspace_large_enough) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    if (_dtype == INFINI_DTYPE_F16) {
        return _device_info->calculate<8, kunlun::SoftplusOp, half>(_info, workspace, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_BF16) {
        return _device_info->calculate<8, kunlun::SoftplusOp, bfloat16_t>(_info, workspace, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_F32) {
        return _device_info->calculate<8, kunlun::SoftplusOp, float>(_info, workspace, output, inputs, stream);
    }

    // Any dtype other than the three handled above is rejected.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::softplus::kunlun
......@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API
#include "metax/softplus_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/softplus_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateSoftplusDescriptor(
infiniopHandle_t handle,
......@@ -43,7 +46,9 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor(
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -74,7 +79,9 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -113,7 +120,9 @@ __C infiniStatus_t infiniopSoftplus(
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -146,7 +155,9 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment