Commit 35388a54 authored by zhangyue
Browse files

Merge branch 'main' of https://github.com/InfiniTensor/InfiniCore into p800-sub

parents 0fe0aea2 72c4dc7c
#ifndef __SWIGLU_KUNLUN_KERNEL_H__
#define __SWIGLU_KUNLUN_KERNEL_H__
namespace op::swiglu::kunlun {
/// @brief Element-wise SwiGLU functor: out = gate * sigmoid(gate) * up.
///
/// The operand array is read as inputs[0] = up and inputs[1] = gate.
typedef struct SwiGLUOp {
private:
    /// Generic logistic function 1 / (1 + exp(-x)); the result is returned as T.
    template <typename T>
    inline __device__ T sigmoid(T x) const {
        // `auto` keeps whatever type the mixed float/T expression yields,
        // so no extra conversion is introduced before the division.
        const auto denom = 1.0f + exp(-x);
        return 1.0f / denom;
    }

    /// Logistic function for plain float operands.
    inline __device__ float sigmoidf(float x) const {
        const auto denom = 1.0f + exp(-x);
        return 1.0f / denom;
    }

public:
    /// Number of input operands consumed per output element.
    /// Every Op is expected to define this static constant.
    static constexpr int num_inputs = 2;

    /// @brief Generic path: evaluates the formula directly in type T.
    /// @param inputs Pointer to the operands; inputs[0] is up, inputs[1] is gate.
    /// @return gate * sigmoid(gate) * up, as T.
    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        const T gate_val = inputs[1];
        const T up_val = inputs[0];
        return gate_val * sigmoid(gate_val) * up_val;
    }

    /// @brief bfloat16 overload: operands are widened to float, the formula is
    ///        evaluated in float, and the result is converted back to bfloat16.
    /// @param inputs Pointer to the operands; inputs[0] is up, inputs[1] is gate.
    inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
        const float gate_val = __bfloat162float(inputs[1]);
        const float up_val = __bfloat162float(inputs[0]);
        return __float2bfloat16(gate_val * sigmoidf(gate_val) * up_val);
    }
} SwiGLUOp;
} // namespace op::swiglu::kunlun
#endif // __SWIGLU_KUNLUN_KERNEL_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h" #include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "swiglu_kunlun.h" #include "swiglu_kunlun.h"
namespace op::elementwise::kunlun { namespace op::elementwise::kunlun {
/// @brief SwiGLU op kernel using SwiGLUOp = op::swiglu::kunlun::SwiGLUOp;
typedef struct SwiGLUOp {
    // Element-wise functor computing gate * sigmoid(gate) * up, where the
    // operand array is read as inputs[0] = up and inputs[1] = gate.
private:
    // Generic logistic function 1 / (1 + exp(-x)); the result is returned as T.
    template <typename T>
    inline __device__ T sigmoid(T x) const {
        // `auto` keeps whatever type the mixed float/T expression yields,
        // so no extra conversion is introduced before the division.
        const auto denom = 1.0f + exp(-x);
        return 1.0f / denom;
    }

    // Logistic function for plain float operands.
    inline __device__ float sigmoidf(float x) const {
        const auto denom = 1.0f + exp(-x);
        return 1.0f / denom;
    }

public:
    // Number of input operands consumed per output element.
    // Every Op is expected to define this static constant.
    static constexpr int num_inputs = 2;

    // Generic path: evaluates the formula directly in type T.
    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        const T gate_val = inputs[1];
        const T up_val = inputs[0];
        return gate_val * sigmoid(gate_val) * up_val;
    }

    // bfloat16 overload: operands are widened to float, the formula is
    // evaluated in float, and the result is converted back to bfloat16.
    inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
        const float gate_val = __bfloat162float(inputs[1]);
        const float up_val = __bfloat162float(inputs[0]);
        return __float2bfloat16(gate_val * sigmoidf(gate_val) * up_val);
    }
} SwiGLUOp;
// __global__ template function instantiation // __global__ template function instantiation
INSTANTIATE_ELEMENTWISE_KERNEL(SwiGLUOp::num_inputs, SwiGLUOp, float); INSTANTIATE_ELEMENTWISE_KERNEL(SwiGLUOp::num_inputs, SwiGLUOp, float);
...@@ -82,11 +53,11 @@ infiniStatus_t Descriptor::calculate( ...@@ -82,11 +53,11 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
return _device_info->calculate<8, op::elementwise::kunlun::SwiGLUOp, float>(_info, workspace, output, inputs, stream); return _device_info->calculate<8, SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return _device_info->calculate<8, op::elementwise::kunlun::SwiGLUOp, half>(_info, workspace, output, inputs, stream); return _device_info->calculate<8, SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16: case INFINI_DTYPE_BF16:
return _device_info->calculate<8, op::elementwise::kunlun::SwiGLUOp, bfloat16_t>(_info, workspace, output, inputs, stream); return _device_info->calculate<8, SwiGLUOp, bfloat16_t>(_info, workspace, output, inputs, stream);
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment