Commit 120b4348 authored by zhangyue's avatar zhangyue
Browse files

issue/9: define macro for different kernel

parent d7385575
......@@ -120,29 +120,27 @@ __aicore__ inline void SwigluKernel<T>::Process() {
}
}
__global__ __aicore__ void swiglu_kernel_half(GM_ADDR c, GM_ADDR a, GM_ADDR b,
int64_t batch, int64_t seq, int64_t hd,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
SwigluKernel<half> op;
op.Init(c, a, b,
batch, seq, hd,
stride_batch_c, stride_batch_a, stride_batch_b,
stride_seq_c, stride_seq_a, stride_seq_b);
op.Process();
}
#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
int64_t batch, int64_t seq, int64_t hd, \
int64_t stride_batch_c, \
int64_t stride_batch_a, \
int64_t stride_batch_b, \
int64_t stride_seq_c, \
int64_t stride_seq_a, \
int64_t stride_seq_b) { \
SwigluKernel<TYPE> op; \
op.Init(c, a, b, \
batch, seq, hd, \
stride_batch_c, stride_batch_a, stride_batch_b, \
stride_seq_c, stride_seq_a, stride_seq_b); \
op.Process(); \
}
__global__ __aicore__ void swiglu_kernel_float(GM_ADDR c, GM_ADDR a, GM_ADDR b,
int64_t batch, int64_t seq, int64_t hd,
int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
SwigluKernel<float> op;
op.Init(c, a, b,
batch, seq, hd,
stride_batch_c, stride_batch_a, stride_batch_b,
stride_seq_c, stride_seq_a, stride_seq_b);
op.Process();
}
DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half)
DEFINE_SWIGLU_KERNEL(swiglu_kernel_float, float)
#undef DEFINE_SWIGLU_KERNEL
extern "C" infiniStatus_t swiglu_kernel_launch(
void *c, void *a, void *b,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment