Commit c2e87202 authored by Catheriany

Merge remote-tracking branch 'origin/main' into issue/142

parents 41818f84 c203635b
#ifndef __RMS_NORM_CUDA_KERNEL_H__
#define __RMS_NORM_CUDA_KERNEL_H__
#include "../../../devices/cuda/cuda_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
INFINIOP_CUDA_KERNEL rmsnormBlock(
......
#ifndef __RMS_NORM_KUNLUN_KERNEL_XPU__
#define __RMS_NORM_KUNLUN_KERNEL_XPU__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
using namespace device::kunlun::kernel;
// Element-wise multiply with RMS scaling: y[i] = w[i] * x[i] * rms
static inline __device__ void elementwiseMulRms(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
int offset_last = count - remain;
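// Worked example (illustrative): count = 100 -> remain = 4, offset_last = 96,
// so elements [96, 100) are handled by the scalar loop below and [0, 96) by
// six float32x16 vector iterations in the loop that follows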
// y[i] = w[i] * x[i] * rms for remainder
for (int i = offset_last; i < count; i++) {
*(y + i) = *(w + i) * *(x + i) * rms;
}
mfence();
float32x16_t v_x;
float32x16_t v_w;
// Do x * w * rms
for (int i = 0; i < offset_last; i += 16) {
v_x = vload_lm_float32x16_mz(x + i);
v_w = vload_lm_float32x16_mz(w + i);
v_x = vvmul_float32x16(v_x, v_w);
v_x = svmul_float32x16(rms, v_x);
vstore_lm_float32x16((y + i), v_x);
mfence();
}
}
// RMSNorm main kernel function
// kunlun2 has 8 clusters and 64 cores
// Launch it as rmsNormKernelF32<<<8, 32, stream>>>()
__global__ void rmsNormKernelF32(float *y, long stride_y, const float *x, long stride_x, const float *w, int m, int n, float epsilon) {
// ncores in a cluster
int ncores = core_num();
// get cid of current core
int cid = core_id();
if (cid >= ncores) {
return;
}
// Divide m rows equally across all clusters
// if m % cluster_num() != 0, clusters with cluster_id() < m % cluster_num() handle one extra row
// [m_start, m_end) is the range of the m dim on the current cluster
int m_start = m / cluster_num() * cluster_id() + min(m % cluster_num(), cluster_id());
int m_end = m_start + (m / cluster_num()) + (cluster_id() < (m % cluster_num()));
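// Worked example (illustrative): m = 10, cluster_num() = 8 -> clusters 0 and 1
// take 2 rows each ([0,2), [2,4)), clusters 2..7 take 1 row each, covering all
// 10 rows with at most one row of skew between clusters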
// max_nn is the max number of elements calculated on one core
const int max_nn = 1024;
// max_mm is the max number of rows calculated on one cluster
const int max_mm = 1024;
// LM cache for reduce
__local__ float x_local[max_nn];
// sm_output is shared mem cache for reduce
__shared__ float sm_output[max_mm];
// LM cache for elementwise mul
__local__ float y_local[max_nn];
__local__ float w_local[max_nn];
while (m_start < m_end) {
// mm is the number of rows handled by this cluster in this iteration
int mm = min(max_mm, m_end - m_start);
// init sm_output; bound by mm so we never write past the shared cache
for (int i = cid; i < mm; i += ncores) {
sm_output[i] = 0.0f;
}
mfence();
sync_cluster();
// each row will be divided into several blocks
// total_block is the number of blocks computed on the current cluster
// curr_block is the block computed on the current core
int total_block = mm * roundup_div(n, max_nn);
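// Worked example (illustrative): n = 2500, max_nn = 1024 ->
// roundup_div(n, max_nn) = 3 blocks per row (1024 + 1024 + 452 elements),
// so total_block = mm * 3 blocks are striped over the cores below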
for (int curr_block = cid; curr_block < total_block; curr_block += ncores) {
// curr_m is the row of curr_block;
// curr_n_start is the first element of current row
// curr_nn is the number of elements of curr_block
int curr_m = curr_block % mm + m_start;
int curr_n_start = (curr_block / mm) * max_nn;
int curr_nn = min(max_nn, n - curr_n_start);
auto x_ptr = x + curr_m * stride_x + curr_n_start;
GM2LM(x_ptr, x_local, curr_nn * sizeof(float));
// do reduce
float ss = op::common_kunlun::reduce_op::sumSquaredF32(x_local, curr_nn);
atomicAddF32(&sm_output[curr_m - m_start], ss);
}
mfence();
sync_cluster();
// do elementwise mul for every row
for (int blk = cid; blk < total_block; blk += ncores) {
// use curr_* names as in the reduce loop above (and avoid shadowing the
// kernel parameter m)
int curr_m = blk % mm + m_start;
int curr_n_start = (blk / mm) * max_nn;
int curr_nn = min(max_nn, n - curr_n_start);
auto x_ptr = x + curr_m * stride_x + curr_n_start;
auto w_ptr = w + curr_n_start;
GM2LM(x_ptr, x_local, curr_nn * sizeof(float));
GM2LM(w_ptr, w_local, curr_nn * sizeof(float));
float ss = SM2REG_atomic(sm_output + curr_m - m_start);
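// RMSNorm scale: rms = 1 / sqrt(mean(x^2) + epsilon); ss holds the sum of
// squares over the full row of n elements accumulated in the reduce pass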
float rms = 1.0f / sqrt(ss / n + epsilon);
elementwiseMulRms(x_local, w_local, y_local, curr_nn, rms);
mfence();
auto y_ptr = y + curr_m * stride_y + curr_n_start;
LM2GM(y_local, y_ptr, curr_nn * sizeof(float));
}
mfence();
sync_cluster();
m_start += max_mm;
}
}
void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream) {
rmsNormKernelF32<<<8, 32, stream>>>((float *)y, stride_y, (const float *)x, stride_x, (const float *)w, m, n, epsilon);
}
#endif
#include "rms_norm_kunlun.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include <memory>
#include <stdint.h>
void rmsNormF32(void *y, long stride_y, const void *x, long stride_x, const void *w, int m, int n, float epsilon, XPUStream stream);
namespace op::rms_norm::kunlun {
struct Descriptor::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
// check ndim before indexing the stride arrays
if (info.ndim() != 2) {
    return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
    return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
new Descriptor::Opaque{static_cast<device::kunlun::Handle *>(handle)->internal()},
info,
0,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t launchKernel(
int m, int n,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
const void *x, ptrdiff_t stride_x,
const void *w, infiniDtype_t wtype,
float epsilon,
kunlunStream_t stream) {
if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
rmsNormF32(y, static_cast<long>(stride_y), x, static_cast<long>(stride_x), w, m, n, epsilon, stream);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w, void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stride_x = _info.x_strides[0];
auto stride_y = _info.y_strides[0];
int n = static_cast<int>(_info.dim());
int m = static_cast<int>(_info.shape[0]);
CHECK_STATUS(launchKernel(m, n, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, reinterpret_cast<kunlunStream_t>(stream)));
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::kunlun
#ifndef __RMS_NORM_KUNLUN_H__
#define __RMS_NORM_KUNLUN_H__
#include "../rms_norm.h"
DESCRIPTOR(kunlun)
#endif
#ifndef __RMS_NORM_MACA_CUH__
#define __RMS_NORM_MACA_CUH__
#include "../rms_norm.h"
DESCRIPTOR(maca)
#endif
#include "../../../devices/maca/common_maca.h"
#include "../cuda/rms_norm_kernel.cuh"
#include "rms_norm_maca.cuh"
namespace op::rms_norm::maca {
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::maca::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
const void *x, ptrdiff_t stride_x,
const void *w, infiniDtype_t wtype,
float epsilon,
hcStream_t maca_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, maca_stream>>>( \
reinterpret_cast<Tdata *>(y), \
stride_y, \
reinterpret_cast<const Tdata *>(x), \
stride_x, \
reinterpret_cast<const Tweight *>(w), \
dim, \
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stride_x = _info.x_strides[0];
auto stride_y = _info.y_strides[0];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
auto maca_stream = reinterpret_cast<hcStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, maca_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::maca
#ifndef __RMS_NORM_MUSA_CUH__
#define __RMS_NORM_MUSA_CUH__
#include "../rms_norm.h"
DESCRIPTOR(musa)
#endif
#include "../../../devices/musa/common_musa.h"
#include "../cuda/rms_norm_kernel.cuh"
#include "rms_norm_musa.cuh"
namespace op::rms_norm::musa {
struct Descriptor::Opaque {
std::shared_ptr<device::musa::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::musa::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y,
const void *x, ptrdiff_t stride_x,
const void *w, infiniDtype_t wtype,
float epsilon,
musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \
stride_y, \
reinterpret_cast<const Tdata *>(x), \
stride_x, \
reinterpret_cast<const Tweight *>(w), \
dim, \
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stride_x = _info.x_strides[0];
auto stride_y = _info.y_strides[0];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::musa
@@ -11,6 +11,15 @@
#ifdef ENABLE_ASCEND_API
#include "ascend/rms_norm_aclnn.h"
#endif
#ifdef ENABLE_METAX_API
#include "maca/rms_norm_maca.cuh"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/rms_norm_musa.cuh"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateRMSNormDescriptor(
infiniopHandle_t handle,
@@ -37,6 +46,9 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangCreateRMSNormDescriptor((BangHandle_t)handle, (RMSNormBangDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
@@ -45,15 +57,11 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaCreateRMSNormDescriptor((MacaHandle_t)handle, (RMSNormMacaDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
}
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca)
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaCreateRMSNormDescriptor((MusaHandle_t)handle, (RMSNormMusaDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
}
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa)
#endif
}
@@ -76,6 +84,9 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangGetRMSNormWorkspaceSize((RMSNormBangDescriptor_t)desc, size);
@@ -84,15 +95,11 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaGetRMSNormWorkspaceSize((RMSNormMacaDescriptor_t)desc, size);
}
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, maca)
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaGetRMSNormWorkspaceSize((RMSNormMusaDescriptor_t)desc, size);
}
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa)
#endif
}
@@ -116,6 +123,9 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangRMSNorm((RMSNormBangDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
@@ -124,15 +134,11 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaRMSNorm((RMSNormMacaDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
}
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, maca)
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaRMSNorm((RMSNormMusaDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
}
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa)
#endif
}
@@ -155,6 +161,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_CUDA_API
DESTROY(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangDestroyRMSNormDescriptor((RMSNormBangDescriptor_t)desc);
@@ -163,15 +172,11 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_ASCEND_API
DESTROY(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaDestroyRMSNormDescriptor((RMSNormMacaDescriptor_t)desc);
}
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, maca)
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaDestroyRMSNormDescriptor((RMSNormMusaDescriptor_t)desc);
}
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, musa)
#endif
}
......
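// --- Hedged usage sketch (illustrative, not part of this commit): how the
// four RMSNorm entry points above chain together. `handle`, the tensor
// descriptors, the device buffers, and `stream` are assumed to come from the
// surrounding infiniop APIs; error handling is reduced to CHECK_STATUS.
//
// infiniopRMSNormDescriptor_t desc;
// CHECK_STATUS(infiniopCreateRMSNormDescriptor(handle, &desc, y_desc, x_desc, w_desc, 1e-5f));
// size_t workspace_size = 0;
// CHECK_STATUS(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size));
// /* allocate `workspace` of workspace_size bytes on the device */
// CHECK_STATUS(infiniopRMSNorm(desc, workspace, workspace_size, y, x, w, stream));
// CHECK_STATUS(infiniopDestroyRMSNormDescriptor(desc));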
#include "rope_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::rope::ascend {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(result);
size_t workspace_size = 0;
*desc_ptr = new Descriptor(result.take(), workspace_size, nullptr, handle_ascend->device, handle_ascend->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
auto data_type = _info.data_type;
auto pos_type = _info.pos_type;
auto seq_len = _info.seqlen;
auto nhead = _info.nhead;
auto dhead = _info.dhead;
auto y_stride_seqlen = _info.y_stride_seqlen;
auto y_stride_nhead = _info.y_stride_nhead;
auto x_stride_seqlen = _info.x_stride_seqlen;
auto x_stride_nhead = _info.x_stride_nhead;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t data_type,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream);
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "../../../devices/ascend/ascend_kernel_common.h"
using namespace AscendC;
template <typename T, typename U>
class RoPEKernel {
public:
__aicore__ inline RoPEKernel() {}
// Init op
// pos position vector
// x input tensor
// y output tensor
// tensor shape [nt, nh, dh]
// make block_num = nh, tile_len = dh
__aicore__ inline void init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh);
__aicore__ inline void process(size_t seq_len);
private:
// Copy a tile into UB
__aicore__ inline void copyIn(size_t i);
__aicore__ inline void compute(size_t i);
__aicore__ inline void copyOut(size_t i);
private:
TPipe pipe;
TQue<QuePosition::VECIN, BUFFER_NUM> _in_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _sin_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _cos_que;
TQue<QuePosition::VECOUT, BUFFER_NUM> _out_que;
TBuf<TPosition::VECCALC> _tmp_odd_buf;
TBuf<TPosition::VECCALC> _tmp_even_buf;
TBuf<TPosition::VECCALC> _tmp_odd_buf1;
TBuf<TPosition::VECCALC> _tmp_odd_buf2;
TBuf<TPosition::VECCALC> _tmp_even_buf1;
TBuf<TPosition::VECCALC> _tmp_even_buf2;
GlobalTensor<T> _x_gm, _y_gm;
GlobalTensor<U> _p_gm;
GlobalTensor<T> _sin_gm;
GlobalTensor<T> _cos_gm;
size_t _block_idx;
size_t _tile_len;
size_t _copy_len;
size_t _half_copy_len;
// stridey[_st_ynt, _st_ynh, 1]
ptrdiff_t _st_ynt;
ptrdiff_t _st_ynh;
// stridex[_st_xnt, _st_xnh, 1]
ptrdiff_t _st_xnt;
ptrdiff_t _st_xnh;
};
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh) {
this->_tile_len = dh;
this->_st_ynt = st_ynt;
this->_st_ynh = st_ynh;
this->_st_xnt = st_xnt;
this->_st_xnh = st_xnh;
_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
// sin/cos tiles hold dh / 2 elements each, so align half the length
_half_copy_len = alignTileLen<T>(dh / 2, BYTE_ALIGN);
_block_idx = GetBlockIdx();
// Init global buffer
_x_gm.SetGlobalBuffer((__gm__ T *)x);
_p_gm.SetGlobalBuffer((__gm__ U *)pos);
_sin_gm.SetGlobalBuffer((__gm__ T *)sin);
_cos_gm.SetGlobalBuffer((__gm__ T *)cos);
_y_gm.SetGlobalBuffer((__gm__ T *)y);
// Init Queue buffer
pipe.InitBuffer(_in_que, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(_out_que, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf2, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf2, _tile_len / 2 * sizeof(T));
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
LocalTensor<T> input_ub = _in_que.AllocTensor<T>();
LocalTensor<T> sin_ub = _sin_que.AllocTensor<T>();
LocalTensor<T> cos_ub = _cos_que.AllocTensor<T>();
// Get idx of current tile in total input
auto idx = i * _st_xnt + _block_idx * _st_xnh;
// Copy the current tile into UB
DataCopy(input_ub, _x_gm[idx], _copy_len);
// Copy sin cos tile
auto pos_idx = _p_gm(i);
DataCopy(sin_ub, _sin_gm[pos_idx * _tile_len / 2], _half_copy_len);
DataCopy(cos_ub, _cos_gm[pos_idx * _tile_len / 2], _half_copy_len);
// Push in operands
_in_que.EnQue(input_ub);
_sin_que.EnQue(sin_ub);
_cos_que.EnQue(cos_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
LocalTensor<T> input_ub = _in_que.DeQue<T>();
LocalTensor<T> sin_ub = _sin_que.DeQue<T>();
LocalTensor<T> cos_ub = _cos_que.DeQue<T>();
LocalTensor<T> output_ub = _out_que.AllocTensor<T>();
LocalTensor<T> tmp_odd = _tmp_odd_buf.Get<T>();
LocalTensor<T> tmp_even = _tmp_even_buf.Get<T>();
LocalTensor<T> tmp_odd1 = _tmp_odd_buf1.Get<T>();
LocalTensor<T> tmp_odd2 = _tmp_odd_buf2.Get<T>();
LocalTensor<T> tmp_even1 = _tmp_even_buf1.Get<T>();
LocalTensor<T> tmp_even2 = _tmp_even_buf2.Get<T>();
// separate odd- and even-position elements
uint64_t rsvdCnt = 0;
GatherMaskParams gMaskParams = {
1,
static_cast<uint16_t>((_tile_len * sizeof(T) + 255) / 256), // repeat times: 256-byte chunks, rounded up (must not exceed 255)
8,
8,
};
GatherMask<T>(tmp_odd, input_ub, 1, false, 0, gMaskParams, rsvdCnt);
GatherMask<T>(tmp_even, input_ub, 2, false, 0, gMaskParams, rsvdCnt);
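// Illustrative effect (using the 1-based odd/even naming of this file):
// input_ub = [x0, x1, x2, x3, ...] ->
// tmp_odd  = [x0, x2, ...] (pattern 1: even 0-based offsets)
// tmp_even = [x1, x3, ...] (pattern 2: odd 0-based offsets)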
PipeBarrier<PIPE_V>();
// compute odd-position outputs
// y_odd = x_odd * cos - x_even * sin
Mul<T>(tmp_odd1, tmp_odd, cos_ub, _tile_len / 2);
Mul<T>(tmp_odd2, tmp_even, sin_ub, _tile_len / 2);
PipeBarrier<PIPE_V>();
Sub<T>(tmp_odd1, tmp_odd1, tmp_odd2, _tile_len / 2);
// compute even-position outputs
// y_even = x_odd * sin + x_even * cos
Mul<T>(tmp_even1, tmp_odd, sin_ub, _tile_len / 2);
Mul<T>(tmp_even2, tmp_even, cos_ub, _tile_len / 2);
PipeBarrier<PIPE_V>();
Add<T>(tmp_even1, tmp_even1, tmp_even2, _tile_len / 2);
// interleave odd- and even-position results back into the output tile
for (uint32_t j = 0; j < _tile_len / 2; j += 1) {
output_ub(j * 2) = tmp_odd1(j);
output_ub(j * 2 + 1) = tmp_even1(j);
}
_out_que.EnQue<T>(output_ub);
_in_que.FreeTensor(input_ub);
_sin_que.FreeTensor(sin_ub);
_cos_que.FreeTensor(cos_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
LocalTensor<T> output_ub = _out_que.DeQue<T>();
auto idy = i * _st_ynt + _block_idx * _st_ynh;
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
DataCopyPad(_y_gm[idy], output_ub, params);
_out_que.FreeTensor(output_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
for (size_t i = 0; i < seq_len; ++i) {
copyIn(i);
compute(i);
copyOut(i);
}
}
#define ROPE_KERNEL_INIT_ARGS y, x, pos, sin, cos, dhead, \
y_stride_seqlen, y_stride_nhead, \
x_stride_seqlen, x_stride_nhead
#define CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T) \
case POS_TYPE_ENUM: { \
RoPEKernel<TYPE, POS_T> op; \
op.init(ROPE_KERNEL_INIT_ARGS); \
op.process(seq_len); \
break; \
}
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
CASE_POSTYPE(INFINI_DTYPE_I8, TYPE, int8_t) \
CASE_POSTYPE(INFINI_DTYPE_I16, TYPE, int16_t) \
CASE_POSTYPE(INFINI_DTYPE_I32, TYPE, int32_t) \
CASE_POSTYPE(INFINI_DTYPE_I64, TYPE, int64_t) \
CASE_POSTYPE(INFINI_DTYPE_U8, TYPE, uint8_t) \
CASE_POSTYPE(INFINI_DTYPE_U16, TYPE, uint16_t) \
CASE_POSTYPE(INFINI_DTYPE_U32, TYPE, uint32_t) \
CASE_POSTYPE(INFINI_DTYPE_U64, TYPE, uint64_t) \
default: \
break; \
}
#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR y, \
GM_ADDR x, \
GM_ADDR pos, \
GM_ADDR sin, \
GM_ADDR cos, \
size_t seq_len, \
size_t dhead, \
ptrdiff_t y_stride_seqlen, \
ptrdiff_t y_stride_nhead, \
ptrdiff_t x_stride_seqlen, \
ptrdiff_t x_stride_nhead, \
int32_t pos_type) { \
ROPE_KERNEL(TYPE, pos_type) \
}
DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)
#undef DEFINE_ROPE_KERNEL
#undef ROPE_KERNEL
#undef CASE_POSTYPE
#undef ROPE_KERNEL_INIT_ARGS
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t dtype,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream) {
#define LAUNCH_ROPE_KERNEL(DTYPE_ENUM, KERNEL_NAME) \
case DTYPE_ENUM: \
KERNEL_NAME<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, \
seq_len, \
dhead, \
y_stride_seqlen, \
y_stride_nhead, \
x_stride_seqlen, \
x_stride_nhead, \
pos_type); \
return INFINI_STATUS_SUCCESS;
switch (dtype) {
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F16, rope_kernel_half)
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F32, rope_kernel_float)
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
#include "rope_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
namespace op::rope::cpu {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(info);
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
0,
nullptr,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table) {
#pragma omp parallel for
for (ptrdiff_t h = 0; h < ptrdiff_t(info.nhead); h++) {
for (size_t tok = 0; tok < info.seqlen; tok++) {
size_t x_offset = tok * info.x_stride_seqlen + h * info.x_stride_nhead;
size_t y_offset = tok * info.y_stride_seqlen + h * info.y_stride_nhead;
size_t pos_id = size_t(pos_ids[tok]);
size_t table_offset = pos_id * info.table_dim;
for (size_t i = 0; i < info.table_dim; i++) {
size_t pos0 = 2 * i;
size_t pos1 = 2 * i + 1;
if constexpr (std::is_same<Tdata, fp16_t>::value) {
float x0 = utils::cast<float>(x[x_offset + pos0]),
x1 = utils::cast<float>(x[x_offset + pos1]),
sin__ = utils::cast<float>(sin_table[table_offset + i]),
cos__ = utils::cast<float>(cos_table[table_offset + i]);
y[y_offset + pos0] = utils::cast<fp16_t>(x0 * cos__ - x1 * sin__);
y[y_offset + pos1] = utils::cast<fp16_t>(x0 * sin__ + x1 * cos__);
} else {
Tdata x0 = x[x_offset + pos0],
x1 = x[x_offset + pos1],
sin__ = sin_table[table_offset + i],
cos__ = cos_table[table_offset + i];
y[y_offset + pos0] = x0 * cos__ - x1 * sin__;
y[y_offset + pos1] = x0 * sin__ + x1 * cos__;
}
}
}
}
return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, (TDATA *)y, (const TDATA *)x, (const TINDEX *)pos_ids, (const TDATA *)sin_table, (const TDATA *)cos_table)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(fp16_t);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
ROPE_TYPE(double);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::cpu
#ifndef __INFINIOP_ROPE_CPU_H__
#define __INFINIOP_ROPE_CPU_H__
#include "../rope.h"
DESCRIPTOR(cpu)
#endif // __INFINIOP_ROPE_CPU_H__
#include "../../../devices/cuda/cuda_common.cuh"
#include "rope_cuda.cuh"
#include "rope_cuda_kernel.cuh"
namespace op::rope::cuda {
struct Descriptor::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(info);
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
0,
new Opaque{reinterpret_cast<device::cuda::Handle *>(handle)->internal()},
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
int block_size,
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
cudaStream_t stream) {
auto dimx = uint32_t(info.seqlen),
dimy = uint32_t(info.nhead);
// cap the thread count at the device limit; the kernel loop strides over table_dim
int nthreads = std::min(int(info.table_dim), block_size);
ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, \
_opaque->internal->maxThreadsPerBlock(), \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(cudaStream_t)stream)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
ROPE_TYPE(double);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::cuda
#ifndef __INFINIOP_ROPE_CUDA_H__
#define __INFINIOP_ROPE_CUDA_H__
#include "../rope.h"
DESCRIPTOR(cuda)
#endif // __INFINIOP_ROPE_CUDA_H__
#ifndef __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#define __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#include "../../../devices/cuda/cuda_kernel_common.cuh"
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropeThreadPerItem(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
auto y_offset = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
auto x_offset = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;
size_t pos_id = size_t(pos_ids[blockIdx.x]);
auto table_offset = pos_id * table_dim;
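// Each thread rotates the pair (2i, 2i + 1) with angle-table entry i, so
// table_dim pairs cover a head dim of 2 * table_dim. The half2 path below
// assumes each pair sits at an even, contiguous offset so it loads as one
// aligned half2 (the stride(2) == 1 check at descriptor creation ensures this)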
for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
Tangle sin__ = sin_table[table_offset + i],
cos__ = cos_table[table_offset + i];
if constexpr (std::is_same<Tdata, half>::value) {
auto &y = reinterpret_cast<half2 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const half2 &>(x_[x_offset + 2 * i]);
Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1);
} else {
Tangle x0 = x_[x_offset + 2 * i],
x1 = x_[x_offset + 2 * i + 1];
y_[y_offset + 2 * i] = Tdata(x0 * cos__ - x1 * sin__);
y_[y_offset + 2 * i + 1] = Tdata(x0 * sin__ + x1 * cos__);
}
}
}
#endif
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/rotary_embedding.h"
#include "infiniop/ops/rope.h"
#ifdef ENABLE_CPU_API
#include "cpu/rope_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/rope_cuda.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_ascend.h"
#endif
__C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle, infiniopRoPEDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t t, infiniopTensorDescriptor_t pos_ids,
infiniopHandle_t handle,
infiniopRoPEDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table) {
switch (handle->device) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuCreateRoPEDescriptor((CpuHandle_t)handle,
(RoPECpuDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaCreateRoPEDescriptor((CudaHandle_t)handle,
(RoPECudaDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
}
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::rope::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::rope::NAMESPACE::Descriptor **>(desc_ptr), \
y, \
x, \
pos_ids, \
sin_table, \
cos_table)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
@@ -29,12 +46,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
pos_ids, sin_table, cos_table);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return ascendCreateRoPEDescriptor((AscendHandle_t)handle,
(RoPEAscendDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
}
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
@@ -51,31 +64,33 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
}
#endif
}
#undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::rope::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuGetRoPEWorkspaceSize((RoPECpuDescriptor_t)desc, size);
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaGetRoPEWorkspaceSize((RoPECudaDescriptor_t)desc, size);
}
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangGetRoPEWorkspaceSize((RoPEBangDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return ascendGetRoPEWorkspaceSize((RoPEAscendDescriptor_t)desc, size);
}
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
@@ -88,26 +103,34 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
}
#endif
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
void *workspace, size_t workspace_size,
void *t, const void *pos_ids,
const void *sin_table, const void *cos_table,
void *stream) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuRoPE((RoPECpuDescriptor_t)desc, workspace, workspace_size, t,
pos_ids, sin_table, cos_table, stream);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaRoPE((RoPECudaDescriptor_t)desc, workspace, workspace_size,
t, pos_ids, sin_table, cos_table, stream);
}
__C infiniStatus_t infiniopRoPE(
infiniopRoPEDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::rope::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, x, pos_ids, sin_table, cos_table, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
@@ -115,12 +138,8 @@ __C infiniStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
t, pos_ids, sin_table, cos_table, stream);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return ascendRoPE((RoPEAscendDescriptor_t)desc, workspace,
workspace_size, t, pos_ids, sin_table, cos_table,
stream);
}
#ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
@@ -135,31 +154,34 @@ __C infiniStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
}
#endif
}
#undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t
infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::rope::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
return cpuDestroyRoPEDescriptor((RoPECpuDescriptor_t)desc);
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NV_GPU
case DevNvGpu: {
return cudaDestroyRoPEDescriptor((RoPECudaDescriptor_t)desc);
}
#ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangDestroyRoPEDescriptor((RoPEBangDescriptor_t)desc);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return ascendDestroyRoPEDescriptor((RoPEAscendDescriptor_t)desc);
}
#ifdef ENABLE_ASCEND_API
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
@@ -172,5 +194,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
}
#endif
}
#undef DELETE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#ifndef __ROPE_H__
#define __ROPE_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::rope::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
RoPEInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
RoPEInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t pos_desc, \
infiniopTensorDescriptor_t sin_desc, \
infiniopTensorDescriptor_t cos_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *y, \
const void *x, \
const void *pos_ids, \
const void *sin_table, \
const void *cos_table, \
void *stream) const; \
}; \
}
class RoPEInfo {
private:
RoPEInfo() = default;
public:
infiniDtype_t data_type, pos_type;
size_t seqlen, nhead, dhead, table_len, table_dim;
ptrdiff_t
y_stride_seqlen,
y_stride_nhead,
x_stride_seqlen,
x_stride_nhead;
static utils::Result<RoPEInfo> createRoPEInfo(
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
CHECK_OR_RETURN(
y_desc != nullptr && x_desc != nullptr && pos_desc != nullptr && sin_desc != nullptr && cos_desc != nullptr,
INFINI_STATUS_NULL_POINTER);
const infiniDtype_t data_type = y_desc->dtype();
const infiniDtype_t pos_type = pos_desc->dtype();
CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(),
INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE_ANY_INT(pos_type);
CHECK_OR_RETURN(y_desc->ndim() == 3
&& x_desc->ndim() == 3
&& pos_desc->ndim() == 1
&& sin_desc->ndim() == 2
&& cos_desc->ndim() == 2,
INFINI_STATUS_BAD_TENSOR_SHAPE);
const auto seqlen = y_desc->dim(0),
nhead = y_desc->dim(1),
dhead = y_desc->dim(2),
table_len = sin_desc->dim(0),
table_dim = sin_desc->dim(1);
CHECK_OR_RETURN(seqlen == x_desc->dim(0)
&& seqlen == pos_desc->dim(0)
&& nhead == x_desc->dim(1) && dhead == x_desc->dim(2)
&& table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1),
INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(dhead == table_dim * 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
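// Example of a consistent set of shapes accepted here: y, x: [seqlen, nhead, 128],
// pos: [seqlen], sin, cos: [table_len, 64], since dhead = 2 * table_dim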
// Last dimension of x and y must be contiguous
CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
// sin table and cos table must be totally contiguous
CHECK_OR_RETURN(sin_desc->stride(1) == 1
&& cos_desc->stride(1) == 1
&& sin_desc->stride(0) == ptrdiff_t(table_dim)
&& cos_desc->stride(0) == ptrdiff_t(table_dim),
INFINI_STATUS_BAD_TENSOR_STRIDES);
return utils::Result<RoPEInfo>(RoPEInfo{
data_type,
pos_type,
seqlen,
nhead,
dhead,
table_len,
table_dim,
y_desc->stride(0),
y_desc->stride(1),
x_desc->stride(0),
x_desc->stride(1),
});
}
};
#endif
#include "swiglu_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::swiglu::ascend {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
std::vector<infiniopTensorDescriptor_t> input_descs) {
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
const auto &a_desc = input_descs[0];
const auto &b_desc = input_descs[1];
auto result = SwigluInfo::create(c_desc, a_desc, b_desc);
CHECK_RESULT(result);
SwigluInfo info = result.take();
// https://www.hiascend.com/document/detail/zh/canncommercial/800/apiref/ascendcopapi/atlasascendc_api_07_0777.html
size_t workspace_size = 0;
*desc_ptr = new Descriptor(std::move(info), workspace_size, handle_ascend->device, handle_ascend->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size,
void *c,
std::vector<const void *> inputs,
void *stream) const {
auto batch = _info.ndim == 2 ? 1 : _info.shape[0];
auto seq_len = _info.ndim == 2 ? _info.shape[0] : _info.shape[1];
auto hidden_size = _info.shape[_info.ndim - 1];
auto stride_batch_c = _info.ndim == 2 ? 1 : _info.c_strides[0];
auto stride_batch_a = _info.ndim == 2 ? 1 : _info.a_strides[0];
auto stride_batch_b = _info.ndim == 2 ? 1 : _info.b_strides[0];
auto stride_seq_c = _info.ndim == 2 ? _info.c_strides[0] : _info.c_strides[1];
auto stride_seq_a = _info.ndim == 2 ? _info.a_strides[0] : _info.a_strides[1];
auto stride_seq_b = _info.ndim == 2 ? _info.b_strides[0] : _info.b_strides[1];
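// Illustrative: with ndim == 3 the shapes are [batch, seq_len, hidden_size] and
// both stride levels come from the descriptors; with ndim == 2
// ([seq_len, hidden_size]) batch degenerates to 1 and the batch strides are
// placeholder 1s, leaving only the seq strides meaningful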
auto status = swiglu_kernel_launch(c, (void *)inputs[0], (void *)inputs[1], _info.dtype, batch, seq_len, hidden_size, stride_batch_c, stride_batch_a, stride_batch_b, stride_seq_c, stride_seq_a, stride_seq_b, stream);
return status;
}
} // namespace op::swiglu::ascend