Commit c94db20d authored by zhangyue's avatar zhangyue
Browse files

issue/349 elementwise-mul

parent 19f3ada5
#ifndef __ADD_KUNLUN_KERNEL_H__
#define __ADD_KUNLUN_KERNEL_H__
namespace op::mul::kunlun {
typedef struct MulOp {
public:
static constexpr int num_inputs = 2;
template <typename T>
inline __device__ T operator()(const T *inputs) const {
T a = inputs[0];
T b = inputs[1];
return a * b;
}
// bfloat16 特化版本(使用 float 计算精度)
inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
float a_f = __bfloat162float(inputs[0]);
float b_f = __bfloat162float(inputs[1]);
return __float2bfloat16(a_f * b_f);
}
} MulOp;
} // namespace op::mul::kunlun
#endif // __ADD_KUNLUN_KERNEL_H__
#ifndef __MUL_KUNLUN_API_H__
#define __MUL_KUNLUN_API_H__
#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"
ELEMENTWISE_DESCRIPTOR(mul, kunlun)
#endif // __MUL_KUNLUN_API_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "mul_kunlun.h"
namespace op::elementwise::kunlun {
using MulOp = op::mul::kunlun::MulOp;
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::mul::kunlun {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &a_desc = input_desc_vec.at(0);
const auto &b_desc = input_desc_vec.at(1);
const auto &c_shape = out_desc->shape();
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
// create KUNLUN elementwise descriptor
CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<8, MulOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<8, MulOp, bfloat16_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<8, MulOp, float>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::mul::kunlun
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "metax/mul_metax.h" #include "metax/mul_metax.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/mul_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateMulDescriptor( __C infiniStatus_t infiniopCreateMulDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor( ...@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -70,6 +76,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz ...@@ -70,6 +76,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax); GET(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -107,6 +116,9 @@ __C infiniStatus_t infiniopMul( ...@@ -107,6 +116,9 @@ __C infiniStatus_t infiniopMul(
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -137,6 +149,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) { ...@@ -137,6 +149,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment