Commit 19f3ada5 authored by zhangyue's avatar zhangyue
Browse files

issue/349 elementwise sub

parent 59e93ab4
#ifndef __ADD_KUNLUN_KERNEL_H__
#define __ADD_KUNLUN_KERNEL_H__
namespace op::sub::kunlun {
typedef struct SubOp {
public:
static constexpr int num_inputs = 2;
template <typename T>
inline __device__ T operator()(const T *inputs) const {
T a = inputs[0];
T b = inputs[1];
return a - b;
}
// bfloat16 特化版本(使用 float 计算精度)
inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
float a_f = __bfloat162float(inputs[0]);
float b_f = __bfloat162float(inputs[1]);
return __float2bfloat16(a_f - b_f);
}
} SubOp;
} // namespace op::sub::kunlun
#endif // __ADD_KUNLUN_KERNEL_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "sub_kunlun.h"
namespace op::elementwise::kunlun {
typedef struct SubOp {
public:
static constexpr int num_inputs = 2;
template <typename T>
inline __device__ T operator()(const T *inputs) const {
T a = inputs[0];
T b = inputs[1];
return a - b;
}
// bfloat16 特化版本(使用 float 计算精度)
inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
float a_f = __bfloat162float(inputs[0]);
float b_f = __bfloat162float(inputs[1]);
return __float2bfloat16(a_f - b_f);
}
} SubOp;
using SubOp = op::sub::kunlun::SubOp;
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, half);
......@@ -67,11 +53,11 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<8, op::elementwise::kunlun::SubOp, half>(_info, workspace, output, inputs, stream);
return _device_info->calculate<8, SubOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<8, op::elementwise::kunlun::SubOp, bfloat16_t>(_info, workspace, output, inputs, stream);
return _device_info->calculate<8, SubOp, bfloat16_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<8, op::elementwise::kunlun::SubOp, float>(_info, workspace, output, inputs, stream);
return _device_info->calculate<8, SubOp, float>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment