Commit 86515765 authored by Ziminli's avatar Ziminli
Browse files

issue/428: merge rope_v2 into rope with algorithm selection

parent 15ac0191
......@@ -15,7 +15,6 @@
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rope.h"
#include "infiniop/ops/rope_v2.h"
#include "infiniop/ops/softplus.h"
#include "infiniop/ops/sub.h"
#include "infiniop/ops/swiglu.h"
......
......@@ -3,6 +3,13 @@
#include "../operator_descriptor.h"
typedef enum {
INFINIOP_ROPE_ALGO_GPT_J = 0, // GPT-J style RoPE algorithm (Interleave even and odd dimensions)
INFINIOP_ROPE_ALGO_GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
// Count
INFINIOP_ROPE_ALGO_COUNT = 2,
} infiniopRoPEAlgo_t;
typedef struct InfiniopDescriptor *infiniopRoPEDescriptor_t;
__C __export infiniStatus_t infiniopCreateRoPEDescriptor(
......@@ -12,7 +19,8 @@ __C __export infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table);
infiniopTensorDescriptor_t cos_table,
infiniopRoPEAlgo_t algo);
__C __export infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, size_t *size);
......
#ifndef __INFINIOP_ROPE_V2_API_H__
#define __INFINIOP_ROPE_V2_API_H__
#include "../operator_descriptor.h"
typedef struct InfiniopDescriptor *infiniopRoPEv2Descriptor_t;
__C __export infiniStatus_t infiniopCreateRoPEv2Descriptor(
infiniopHandle_t handle,
infiniopRoPEv2Descriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table);
__C __export infiniStatus_t infiniopGetRoPEv2WorkspaceSize(infiniopRoPEv2Descriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopRoPEv2(
infiniopRoPEv2Descriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void const *pos_ids,
void const *sin_table,
void const *cos_table,
void *stream);
__C __export infiniStatus_t infiniopDestroyRoPEv2Descriptor(infiniopRoPEv2Descriptor_t desc);
#endif
......@@ -12,11 +12,12 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(info);
// Create descriptor
......@@ -46,8 +47,8 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
size_t table_offset = pos_id * info.table_dim;
for (size_t i = 0; i < info.table_dim; i++) {
size_t pos0 = 2 * i;
size_t pos1 = 2 * i + 1;
size_t pos0 = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J ? 2 * i : i;
size_t pos1 = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J ? 2 * i + 1 : i + info.table_dim;
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
float x0 = utils::cast<float>(x[x_offset + pos0]),
......
#ifndef __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#define __INFINIOP_ROPE_CUDA_KERNEL_CUH__
template <typename Tdata, typename Tindex, typename Tangle>
template <bool IsGPTJ, typename Tdata, typename Tindex, typename Tangle>
__device__ void ropeThreadPerItemBlock(
Tdata *y_,
const Tdata *x_,
......@@ -22,28 +22,60 @@ __device__ void ropeThreadPerItemBlock(
for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
Tangle sin__ = sin_table[table_offset + i],
cos__ = cos_table[table_offset + i];
if constexpr (std::is_same<Tdata, half>::value) {
auto &y = reinterpret_cast<half2 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const half2 &>(x_[x_offset + 2 * i]);
Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
auto &y = reinterpret_cast<cuda_bfloat162 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const cuda_bfloat162 &>(x_[x_offset + 2 * i]);
Tangle x0 = __low2bfloat16(x);
Tangle x1 = __high2bfloat16(x);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y = __floats2bfloat162_rn(y0, y1);
if constexpr (IsGPTJ) {
if constexpr (std::is_same<Tdata, half>::value) {
auto &y = reinterpret_cast<half2 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const half2 &>(x_[x_offset + 2 * i]);
Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
auto &y = reinterpret_cast<cuda_bfloat162 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const cuda_bfloat162 &>(x_[x_offset + 2 * i]);
Tangle x0 = __low2bfloat16(x);
Tangle x1 = __high2bfloat16(x);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y = __floats2bfloat162_rn(y0, y1);
} else {
Tangle x0 = x_[x_offset + 2 * i],
x1 = x_[x_offset + 2 * i + 1];
y_[y_offset + 2 * i] = Tdata(x0 * cos__ - x1 * sin__);
y_[y_offset + 2 * i + 1] = Tdata(x0 * sin__ + x1 * cos__);
}
} else {
Tangle x0 = x_[x_offset + 2 * i],
x1 = x_[x_offset + 2 * i + 1];
y_[y_offset + 2 * i] = Tdata(x0 * cos__ - x1 * sin__);
y_[y_offset + 2 * i + 1] = Tdata(x0 * sin__ + x1 * cos__);
size_t pos0 = i;
size_t pos1 = i + table_dim;
if constexpr (std::is_same<Tdata, half>::value) {
Tangle x0 = __half2float(x_[x_offset + pos0]);
Tangle x1 = __half2float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2half(y0);
y_[y_offset + pos1] = __float2half(y1);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
Tangle x0 = __bfloat162float(x_[x_offset + pos0]);
Tangle x1 = __bfloat162float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2bfloat16(y0);
y_[y_offset + pos1] = __float2bfloat16(y1);
} else {
Tangle x0 = x_[x_offset + pos0];
Tangle x1 = x_[x_offset + pos1];
y_[y_offset + pos0] = x0 * cos__ - x1 * sin__;
y_[y_offset + pos1] = x0 * sin__ + x1 * cos__;
}
}
}
}
......
......@@ -5,7 +5,7 @@
#include "../cuda/kernel.cuh"
template <typename Tdata, typename Tindex, typename Tangle>
template <bool IsGPTJ, typename Tdata, typename Tindex, typename Tangle>
INFINIOP_METAX_KERNEL ropeThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
......@@ -17,7 +17,7 @@ INFINIOP_METAX_KERNEL ropeThreadPerItemKernel(
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock(
ropeThreadPerItemBlock<IsGPTJ>(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
......@@ -42,11 +42,12 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(info);
// Create descriptor
......@@ -72,10 +73,17 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
auto dimx = uint32_t(info.seqlen),
dimy = uint32_t(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size);
ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
bool is_gpt_j = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J;
if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
}
return INFINI_STATUS_SUCCESS;
}
......
......@@ -8,7 +8,7 @@
* which ensuring code alignment across different hardware platforms.
*/
template <typename Tdata, typename Tindex, typename Tangle>
template <bool IsGPTJ, typename Tdata, typename Tindex, typename Tangle>
__device__ void ropeThreadPerItemBlock(
Tdata *y_,
const Tdata *x_,
......@@ -29,40 +29,72 @@ __device__ void ropeThreadPerItemBlock(
for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
Tangle sin__ = sin_table[table_offset + i],
cos__ = cos_table[table_offset + i];
if constexpr (std::is_same<Tdata, half>::value) {
auto &y = reinterpret_cast<half2 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const half2 &>(x_[x_offset + 2 * i]);
Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
auto &y = reinterpret_cast<cuda_bfloat162 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const cuda_bfloat162 &>(x_[x_offset + 2 * i]);
/*
* The original code used CUDA-specific functions (__low2bfloat16, __high2bfloat16)
* to extract bfloat16 values from a packed variable.
*
* This code has been modified for the MUSA platform, which does not support
* these CUDA built-in functions. Instead, MUSA provides a different set of
* built-in functions (`__low2float`, `__high2float`) that directly convert
* the bfloat16 values to float.
*
* This change ensures cross-platform compatibility and resolves compilation errors.
*/
Tangle x0 = __low2float(x);
Tangle x1 = __high2float(x);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y = __floats2bfloat162_rn(y0, y1);
if constexpr (IsGPTJ) {
if constexpr (std::is_same<Tdata, half>::value) {
auto &y = reinterpret_cast<half2 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const half2 &>(x_[x_offset + 2 * i]);
Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
auto &y = reinterpret_cast<cuda_bfloat162 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const cuda_bfloat162 &>(x_[x_offset + 2 * i]);
/*
* The original code used CUDA-specific functions (__low2bfloat16, __high2bfloat16)
* to extract bfloat16 values from a packed variable.
*
* This code has been modified for the MUSA platform, which does not support
* these CUDA built-in functions. Instead, MUSA provides a different set of
* built-in functions (`__low2float`, `__high2float`) that directly convert
* the bfloat16 values to float.
*
* This change ensures cross-platform compatibility and resolves compilation errors.
*/
Tangle x0 = __low2float(x);
Tangle x1 = __high2float(x);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y = __floats2bfloat162_rn(y0, y1);
} else {
Tangle x0 = x_[x_offset + 2 * i],
x1 = x_[x_offset + 2 * i + 1];
y_[y_offset + 2 * i] = Tdata(x0 * cos__ - x1 * sin__);
y_[y_offset + 2 * i + 1] = Tdata(x0 * sin__ + x1 * cos__);
}
} else {
Tangle x0 = x_[x_offset + 2 * i],
x1 = x_[x_offset + 2 * i + 1];
y_[y_offset + 2 * i] = Tdata(x0 * cos__ - x1 * sin__);
y_[y_offset + 2 * i + 1] = Tdata(x0 * sin__ + x1 * cos__);
size_t pos0 = i;
size_t pos1 = i + table_dim;
if constexpr (std::is_same<Tdata, half>::value) {
Tangle x0 = __half2float(x_[x_offset + pos0]);
Tangle x1 = __half2float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2half(y0);
y_[y_offset + pos1] = __float2half(y1);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
Tangle x0 = __bfloat162float(x_[x_offset + pos0]);
Tangle x1 = __bfloat162float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2bfloat16(y0);
y_[y_offset + pos1] = __float2bfloat16(y1);
} else {
Tangle x0 = x_[x_offset + pos0];
Tangle x1 = x_[x_offset + pos1];
y_[y_offset + pos0] = x0 * cos__ - x1 * sin__;
y_[y_offset + pos1] = x0 * sin__ + x1 * cos__;
}
}
}
}
......
......@@ -5,7 +5,7 @@
#include "rope_kernel_moore.h"
template <typename Tdata, typename Tindex, typename Tangle>
template <bool IsGPTJ, typename Tdata, typename Tindex, typename Tangle>
INFINIOP_MOORE_KERNEL ropeThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
......@@ -17,7 +17,7 @@ INFINIOP_MOORE_KERNEL ropeThreadPerItemKernel(
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock(
ropeThreadPerItemBlock<IsGPTJ>(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
......@@ -42,11 +42,12 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(info);
// Create descriptor
......@@ -72,10 +73,17 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
auto dimx = uint32_t(info.seqlen),
dimy = uint32_t(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size);
ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
bool is_gpt_j = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J;
if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
}
return INFINI_STATUS_SUCCESS;
}
......
......@@ -5,7 +5,7 @@
#include "../cuda/kernel.cuh"
template <typename Tdata, typename Tindex, typename Tangle>
template <bool IsGPTJ, typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
......@@ -17,7 +17,7 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel(
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock(
ropeThreadPerItemBlock<IsGPTJ>(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
......@@ -42,11 +42,12 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(info);
// Create descriptor
......@@ -72,10 +73,17 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
auto dimx = uint32_t(info.seqlen),
dimy = uint32_t(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size);
ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
bool is_gpt_j = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J;
if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
}
return INFINI_STATUS_SUCCESS;
}
......
......@@ -31,7 +31,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table) {
infiniopTensorDescriptor_t cos_table,
infiniopRoPEAlgo_t algo) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
......@@ -42,7 +43,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
x, \
pos_ids, \
sin_table, \
cos_table)
cos_table, \
algo)
switch (handle->device) {
#ifdef ENABLE_CPU_API
......
......@@ -4,6 +4,7 @@
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "infiniop/ops/rope.h"
#define DESCRIPTOR(NAMESPACE) \
\
......@@ -37,7 +38,8 @@
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t pos_desc, \
infiniopTensorDescriptor_t sin_desc, \
infiniopTensorDescriptor_t cos_desc); \
infiniopTensorDescriptor_t cos_desc, \
infiniopRoPEAlgo_t algo); \
\
infiniStatus_t calculate( \
void *workspace, \
......@@ -63,15 +65,18 @@ public:
y_stride_nhead,
x_stride_seqlen,
x_stride_nhead;
infiniopRoPEAlgo_t algo;
static utils::Result<RoPEInfo> createRoPEInfo(
static utils::Result<RoPEInfo>
createRoPEInfo(
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
CHECK_OR_RETURN(
y_desc != nullptr && pos_desc != nullptr && sin_desc != nullptr && cos_desc != nullptr,
y_desc != nullptr && pos_desc != nullptr && sin_desc != nullptr && cos_desc != nullptr && algo < infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_COUNT,
INFINI_STATUS_NULL_POINTER);
const infiniDtype_t data_type = y_desc->dtype();
......@@ -118,6 +123,7 @@ public:
y_desc->stride(1),
x_desc->stride(0),
x_desc->stride(1),
algo,
});
}
};
......
#include "rope_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::rope::ascend {
Descriptor::~Descriptor()
= default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle_ascned = reinterpret_cast<device::ascend::Handle *>(handle);
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(result);
size_t workspace_size = 0;
*desc_ptr = new Descriptor(std::move(result.take()), workspace_size, nullptr, handle_ascned->device, handle_ascned->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
auto data_type = _info.data_type;
auto pos_type = _info.pos_type;
auto seq_len = _info.seqlen;
auto nhead = _info.nhead;
auto dhead = _info.dhead;
auto y_stride_seqlen = _info.y_stride_seqlen;
auto y_stride_nhead = _info.y_stride_nhead;
auto x_stride_seqlen = _info.x_stride_seqlen;
auto x_stride_nhead = _info.x_stride_nhead;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t data_type,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream);
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "../../../devices/ascend/ascend_kernel_common.h"
using namespace AscendC;
template <typename T, typename U>
class RoPEKernel {
public:
__aicore__ inline RoPEKernel() {}
// Init op
// pos position vector
// x input tensor
// y output tensor
// tensor shape [nt, nh, dh]
// make block_num = nh, tile_len = dh
__aicore__ inline void init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh);
__aicore__ inline void process(size_t seq_len);
private:
// Copy a tile into UB
__aicore__ inline void copyIn(size_t i);
__aicore__ inline void compute(size_t i);
__aicore__ inline void copyOut(size_t i);
private:
TPipe pipe;
TQue<QuePosition::VECIN, BUFFER_NUM> _in_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _sin_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _cos_que;
TQue<QuePosition::VECOUT, BUFFER_NUM> _out_que;
TBuf<TPosition::VECCALC> _tmp_odd_buf;
TBuf<TPosition::VECCALC> _tmp_even_buf;
TBuf<TPosition::VECCALC> _tmp_odd_buf1;
TBuf<TPosition::VECCALC> _tmp_odd_buf2;
TBuf<TPosition::VECCALC> _tmp_even_buf1;
TBuf<TPosition::VECCALC> _tmp_even_buf2;
GlobalTensor<T> _x_gm, _y_gm;
GlobalTensor<U> _p_gm;
GlobalTensor<T> _sin_gm;
GlobalTensor<T> _cos_gm;
size_t _block_idx;
size_t _tile_len;
size_t _copy_len;
size_t _half_copy_len;
// stridey[_st_ynt, _st_ynh, 1]
ptrdiff_t _st_ynt;
ptrdiff_t _st_ynh;
// stridex[_st_xnt, _st_xnh, 1]
ptrdiff_t _st_xnt;
ptrdiff_t _st_xnh;
};
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh) {
this->_tile_len = dh;
this->_st_ynt = st_ynt;
this->_st_ynh = st_ynh;
this->_st_xnt = st_xnt;
this->_st_xnh = st_xnh;
_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
_half_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
_block_idx = GetBlockIdx();
// Init global buffer
_x_gm.SetGlobalBuffer((__gm__ T *)x);
_p_gm.SetGlobalBuffer((__gm__ U *)pos);
_sin_gm.SetGlobalBuffer((__gm__ T *)sin);
_cos_gm.SetGlobalBuffer((__gm__ T *)cos);
_y_gm.SetGlobalBuffer((__gm__ T *)y);
// Init Queue buffer
pipe.InitBuffer(_in_que, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(_out_que, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf2, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf2, _tile_len / 2 * sizeof(T));
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
LocalTensor<T> input_ub = _in_que.AllocTensor<T>();
LocalTensor<T> sin_ub = _sin_que.AllocTensor<T>();
LocalTensor<T> cos_ub = _cos_que.AllocTensor<T>();
// Get idx of current tile in total input
auto idx = i * _st_xnt + _block_idx * _st_xnh;
// Copy tile current tile into UB
DataCopy(input_ub, _x_gm[idx], _copy_len);
// Copy sin cos tile
auto pos_idx = _p_gm(i);
DataCopy(sin_ub, _sin_gm[pos_idx * _tile_len / 2], _half_copy_len);
DataCopy(cos_ub, _cos_gm[pos_idx * _tile_len / 2], _half_copy_len);
// Push in operands
_in_que.EnQue(input_ub);
_sin_que.EnQue(sin_ub);
_cos_que.EnQue(cos_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
LocalTensor<T> input_ub = _in_que.DeQue<T>();
LocalTensor<T> sin_ub = _sin_que.DeQue<T>();
LocalTensor<T> cos_ub = _cos_que.DeQue<T>();
LocalTensor<T> output_ub = _out_que.AllocTensor<T>();
LocalTensor<T> tmp_odd = _tmp_odd_buf.Get<T>();
LocalTensor<T> tmp_even = _tmp_even_buf.Get<T>();
LocalTensor<T> tmp_odd1 = _tmp_odd_buf1.Get<T>();
LocalTensor<T> tmp_odd2 = _tmp_odd_buf2.Get<T>();
LocalTensor<T> tmp_even1 = _tmp_even_buf1.Get<T>();
LocalTensor<T> tmp_even2 = _tmp_even_buf2.Get<T>();
// separate odd and even bit elements
uint64_t rsvdCnt = 0;
GatherMaskParams gMaskParams = {
1,
static_cast<uint16_t>((_tile_len * sizeof(T) + 255) / 256), // no more than 256(<=255)
8,
8,
};
GatherMask<T>(tmp_odd, input_ub, 1, false, 0, gMaskParams, rsvdCnt);
GatherMask<T>(tmp_even, input_ub, 2, false, 0, gMaskParams, rsvdCnt);
PipeBarrier<PIPE_V>();
// compute odd bit elements
// y_odd = x_odd * cos - x_even * sin
Mul<T>(tmp_odd1, tmp_odd, cos_ub, _tile_len / 2);
Mul<T>(tmp_odd2, tmp_even, sin_ub, _tile_len / 2);
PipeBarrier<PIPE_V>();
Sub<T>(tmp_odd1, tmp_odd1, tmp_odd2, _tile_len / 2);
// compute even bit elements
// y_even = x_odd * sin + x_even * cos
Mul<T>(tmp_even1, tmp_odd, sin_ub, _tile_len / 2);
Mul<T>(tmp_even2, tmp_even, cos_ub, _tile_len / 2);
PipeBarrier<PIPE_V>();
Add<T>(tmp_even1, tmp_even1, tmp_even2, _tile_len / 2);
// combine odd and even bit elements
for (uint32_t j = 0; j < _tile_len / 2; j += 1) {
output_ub(j * 2) = tmp_odd1(j);
output_ub(j * 2 + 1) = tmp_even1(j);
}
_out_que.EnQue<T>(output_ub);
_in_que.FreeTensor(input_ub);
_sin_que.FreeTensor(sin_ub);
_cos_que.FreeTensor(cos_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
LocalTensor<T> output_ub = _out_que.DeQue<T>();
auto idy = i * _st_ynt + _block_idx * _st_ynh;
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
DataCopyPad(_y_gm[idy], output_ub, params);
_out_que.FreeTensor(output_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
for (size_t i = 0; i < seq_len; ++i) {
copyIn(i);
compute(i);
copyOut(i);
}
}
#define ROPE_KERNEL_INIT_ARGS y, x, pos, sin, cos, dhead, \
y_stride_seqlen, y_stride_nhead, \
x_stride_seqlen, x_stride_nhead
#define CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T) \
case POS_TYPE_ENUM: { \
RoPEKernel<TYPE, POS_T> op; \
op.init(ROPE_KERNEL_INIT_ARGS); \
op.process(seq_len); \
break; \
}
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
CASE_POSTYPE(INFINI_DTYPE_I8, TYPE, int8_t) \
CASE_POSTYPE(INFINI_DTYPE_I16, TYPE, int16_t) \
CASE_POSTYPE(INFINI_DTYPE_I32, TYPE, int32_t) \
CASE_POSTYPE(INFINI_DTYPE_I64, TYPE, int64_t) \
CASE_POSTYPE(INFINI_DTYPE_U8, TYPE, uint8_t) \
CASE_POSTYPE(INFINI_DTYPE_U16, TYPE, uint16_t) \
CASE_POSTYPE(INFINI_DTYPE_U32, TYPE, uint32_t) \
CASE_POSTYPE(INFINI_DTYPE_U64, TYPE, uint64_t) \
default: \
break; \
}
#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR y, \
GM_ADDR x, \
GM_ADDR pos, \
GM_ADDR sin, \
GM_ADDR cos, \
size_t seq_len, \
size_t dhead, \
ptrdiff_t y_stride_seqlen, \
ptrdiff_t y_stride_nhead, \
ptrdiff_t x_stride_seqlen, \
ptrdiff_t x_stride_nhead, \
int32_t pos_type) { \
ROPE_KERNEL(TYPE, pos_type) \
}
DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)
#undef DEFINE_ROPE_KERNEL
#undef ROPE_KERNEL
#undef CASE_POSTYPE
#undef ROPE_KERNEL_INIT_ARGS
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t dtype,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream) {
#define LAUNCH_ROPE_KERNEL(DTYPE_ENUM, KERNEL_NAME) \
case DTYPE_ENUM: \
KERNEL_NAME<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, \
seq_len, \
dhead, \
y_stride_seqlen, \
y_stride_nhead, \
x_stride_seqlen, \
x_stride_nhead, \
pos_type); \
return INFINI_STATUS_SUCCESS;
switch (dtype) {
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F16, rope_kernel_half)
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F32, rope_kernel_float)
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
#ifndef __INFINIOP_ROPE_BANG_H__
#define __INFINIOP_ROPE_BANG_H__
#include "../rope.h"
DESCRIPTOR(bang)
#endif // __INFINIOP_ROPE_BANG_H__
#include "../../../devices/bang/common_bang.h"
#include "rope_bang.h"
#include "rope_bang_kernel.mlu"
namespace op::rope::bang {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(info);
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
0,
nullptr,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
cnrtQueue_t queue) {
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
auto table_dim = uint32_t(info.table_dim);
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
// Configure kernel launch parameters
k_dim.x = 4;
k_dim.y = 1;
k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1;
// Launch kernel
ropeKernel<<<k_dim, k_type, queue>>>(
y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead);
cnrtQueueSync(queue);
return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(cnrtQueue_t)stream)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(bfloat16_t);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::bang
#include "../../../devices/bang/common_bang.h"
__nram__ char nram_buffer[NRAM_MAX_SIZE];
template <typename Tdata>
__mlu_device__ void calculateRope(
Tdata *out, const Tdata *in,
const Tdata *sin_table, const Tdata *cos_table,
Tdata *sin_cache, Tdata *cos_cache,
Tdata *x1sin, Tdata *x0cos, Tdata *x0sin, Tdata *x1cos,
Tdata *input_0, Tdata *input_1, Tdata *input_cache,
int theta_index, int out_index, int in_index,
int chunk_size, int half_chunk_size, int data_segsize,
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
// Load sin/cos data
__memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
__memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Load input data
__memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Split input into even and odd positions
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
// Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos
__bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
__bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
__bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
__bang_mul(x1cos, input_1, cos_cache, half_chunk_size);
__bang_sub(input_0, x0cos, x1sin, half_chunk_size);
__bang_add(input_1, x0sin, x1cos, half_chunk_size);
// Interleave results back into output buffer
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
// Write back results
__memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
}
template <typename Tdata, typename Tindex>
__mlu_global__ void ropeKernel(
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
uint32_t seqlen,
uint32_t nhead,
uint32_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
// Calculate available NRAM space after alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));
// Key variables that determine execution path
const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
// Common stride configurations
const int data_segsize = sizeof(Tdata);
const int src_load_stride = 2 * sizeof(Tdata);
const int dst_load_stride = 1 * sizeof(Tdata);
const int src_write_stride = 1 * sizeof(Tdata);
const int dst_write_stride = 2 * sizeof(Tdata);
// Task distribution
const int batch_volume = seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
const int task_start_idx = (taskId < remaining_tasks ? taskId * base_tasks_per_core + taskId : taskId * base_tasks_per_core + remaining_tasks);
// NRAM buffer allocation with proper alignment
char *aligned_nram = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
// Setup position IDs if they fit in NRAM
Tindex *srcP = nullptr;
if (use_pos_ids_buffer) {
srcP = (Tindex *)aligned_nram;
__memcpy(srcP, pos_ids, seqlen * sizeof(Tindex), GDRAM2NRAM);
aligned_nram = (char *)(((size_t)srcP + seqlen * sizeof(Tindex) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
}
// Main processing buffers (pointers will be set per chunk)
Tdata *sin_cache = nullptr;
Tdata *cos_cache = nullptr;
Tdata *x1sin = nullptr;
Tdata *x0cos = nullptr;
Tdata *x0sin = nullptr;
Tdata *x1cos = nullptr;
Tdata *input_0 = nullptr;
Tdata *input_1 = nullptr;
Tdata *input_cache = nullptr;
// Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
// Calculate output and input indices
int seq_idx = i / nhead;
int head_idx = i % nhead;
// Output indices (y)
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
// Input indices (x)
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;
// Process in chunks that fit in NRAM
int processed = 0;
while (processed < table_dim) {
// Calculate current chunk size
int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
int current_chunk_size = 2 * current_half_chunk;
int theta_offset = rot_offset + processed;
int dst_offset = out_offset + processed * 2;
int src_offset = in_offset + processed * 2;
// Set up NRAM buffers for this chunk
char *chunk_base = aligned_nram;
sin_cache = (Tdata *)chunk_base;
cos_cache = sin_cache + current_half_chunk;
x1sin = cos_cache + current_half_chunk;
x0cos = x1sin + current_half_chunk;
x0sin = x0cos + current_half_chunk;
x1cos = x0sin + current_half_chunk;
input_0 = x1cos + current_half_chunk;
input_1 = input_0 + current_half_chunk;
input_cache = input_1 + current_half_chunk;
calculateRope<Tdata>(
y, x, sin_table, cos_table,
sin_cache, cos_cache, x1sin, x0cos, x0sin, x1cos,
input_0, input_1, input_cache,
theta_offset, dst_offset, src_offset,
current_chunk_size, current_half_chunk,
data_segsize,
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
processed += current_half_chunk;
}
}
}
#include "rope_v2_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
namespace op::rope_v2::cpu {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto info = RoPEv2Info::createRoPEv2Info(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(info);
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
0,
nullptr,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPEv2(const RoPEv2Info &info,
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table) {
#pragma omp parallel for
for (ptrdiff_t h = 0; h < ptrdiff_t(info.nhead); h++) {
for (size_t tok = 0; tok < info.seqlen; tok++) {
size_t x_offset = tok * info.x_stride_seqlen + h * info.x_stride_nhead;
size_t y_offset = tok * info.y_stride_seqlen + h * info.y_stride_nhead;
size_t pos_id = size_t(pos_ids[tok]);
size_t table_offset = pos_id * info.table_dim;
size_t half_dim = info.table_dim; // head_dim = 2 * half_dim
for (size_t i = 0; i < info.table_dim; i++) {
// Pair elements from first half and second half
size_t pos0 = i;
size_t pos1 = i + half_dim;
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
float x0 = utils::cast<float>(x[x_offset + pos0]),
x1 = utils::cast<float>(x[x_offset + pos1]),
sin__ = utils::cast<float>(sin_table[table_offset + i]),
cos__ = utils::cast<float>(cos_table[table_offset + i]);
y[y_offset + pos0] = utils::cast<Tdata>(x0 * cos__ - x1 * sin__);
y[y_offset + pos1] = utils::cast<Tdata>(x0 * sin__ + x1 * cos__);
} else {
Tdata x0 = x[x_offset + pos0],
x1 = x[x_offset + pos1],
sin__ = sin_table[table_offset + i],
cos__ = cos_table[table_offset + i];
y[y_offset + pos0] = x0 * cos__ - x1 * sin__;
y[y_offset + pos1] = x0 * sin__ + x1 * cos__;
}
}
}
}
return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE_V2(TDATA, TINDEX) \
calculateRoPEv2(_info, (TDATA *)y, (const TDATA *)x, (const TINDEX *)pos_ids, (const TDATA *)sin_table, (const TDATA *)cos_table)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE_V2(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE_V2(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE_V2(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE_V2(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE_V2(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE_V2(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE_V2(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE_V2(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(fp16_t);
case INFINI_DTYPE_BF16:
ROPE_TYPE(bf16_t);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
ROPE_TYPE(double);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope_v2::cpu
#ifndef __INFINIOP_ROPE_V2_CPU_H__
#define __INFINIOP_ROPE_V2_CPU_H__
#include "../rope_v2.h"
DESCRIPTOR(cpu)
#endif // __INFINIOP_ROPE_V2_CPU_H__
#ifndef __INFINIOP_ROPE_V2_CUDA_KERNEL_CUH__
#define __INFINIOP_ROPE_V2_CUDA_KERNEL_CUH__
template <typename Tdata, typename Tindex, typename Tangle>
__device__ void ropeThreadPerItemBlock(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
auto y_offset = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
auto x_offset = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;
size_t pos_id = size_t(pos_ids[blockIdx.x]);
auto table_offset = pos_id * table_dim;
const size_t half_dim = table_dim; // Head dimension = 2 * table_dim
for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
Tangle sin__ = sin_table[table_offset + i];
Tangle cos__ = cos_table[table_offset + i];
// Calculate positions in first and second halves
size_t pos0 = i;
size_t pos1 = i + half_dim;
if constexpr (std::is_same<Tdata, half>::value) {
Tangle x0 = __half2float(x_[x_offset + pos0]);
Tangle x1 = __half2float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2half(y0);
y_[y_offset + pos1] = __float2half(y1);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
Tangle x0 = __bfloat162float(x_[x_offset + pos0]);
Tangle x1 = __bfloat162float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2bfloat16(y0);
y_[y_offset + pos1] = __float2bfloat16(y1);
} else {
Tangle x0 = x_[x_offset + pos0];
Tangle x1 = x_[x_offset + pos1];
y_[y_offset + pos0] = x0 * cos__ - x1 * sin__;
y_[y_offset + pos1] = x0 * sin__ + x1 * cos__;
}
}
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment