Unverified Commit f9d16628 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #429 from InfiniTensor/issue/428_merge_rope_and_rope_v2

Issue/428: Merge `rope_v2` into `rope`
parents 15ac0191 9f0ae734
#include "../../../devices/bang/common_bang.h"
// Per-core NRAM scratch space; ropeKernel carves all of its working buffers
// (pos-id cache plus the nine per-chunk tensors) out of this single arena.
__nram__ char nram_buffer[NRAM_MAX_SIZE];
// Applies rotary position embedding to one chunk of one (token, head) row.
// All scratch pointers (sin_cache ... input_cache) are pre-carved NRAM
// buffers of at least half_chunk_size elements (input_cache: chunk_size).
// The strided __memcpy calls de-interleave the input into even/odd lanes,
// the __bang ops rotate each (x0, x1) pair by (cos, sin), and the results
// are re-interleaved and written back to GDRAM.
// NOTE(review): this rotates interleaved even/odd pairs (GPT-J layout),
// whereas the CPU/CUDA rope_v2 paths in this commit rotate split halves
// (NEOX layout) — confirm which layout this backend is expected to produce.
template <typename Tdata>
__mlu_device__ void calculateRope(
Tdata *out, const Tdata *in,
const Tdata *sin_table, const Tdata *cos_table,
Tdata *sin_cache, Tdata *cos_cache,
Tdata *x1sin, Tdata *x0cos, Tdata *x0sin, Tdata *x1cos,
Tdata *input_0, Tdata *input_1, Tdata *input_cache,
int theta_index, int out_index, int in_index,
int chunk_size, int half_chunk_size, int data_segsize,
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
// Load sin/cos rows for this chunk from the GDRAM tables into NRAM.
__memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
__memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Load the input chunk (chunk_size = 2 * half_chunk_size elements).
__memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Split input into even (input_0) and odd (input_1) positions via strided
// copies: data_segsize bytes per segment, half_chunk_size segments total.
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
// Rotation: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos,
// computed element-wise over the half-chunk with BANG vector ops.
__bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
__bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
__bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
__bang_mul(x1cos, input_1, cos_cache, half_chunk_size);
__bang_sub(input_0, x0cos, x1sin, half_chunk_size);
__bang_add(input_1, x0sin, x1cos, half_chunk_size);
// Interleave results back into output buffer (inverse of the split above).
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
// Write the rotated chunk back to GDRAM.
__memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
}
// BANG entry point: applies RoPE to x[seqlen, nhead, 2*table_dim] into y.
// The seqlen*nhead (token, head) rows are distributed evenly over taskDim
// cores; each row is processed in NRAM-sized chunks by calculateRope.
// pos_ids[seq] selects the row of the sin/cos tables for that token.
template <typename Tdata, typename Tindex>
__mlu_global__ void ropeKernel(
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
uint32_t seqlen,
uint32_t nhead,
uint32_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
// Calculate available NRAM space after alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));
// Cache pos_ids in NRAM only when it fits in (at most) half the arena,
// leaving the rest for the per-chunk buffers.
const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
// Stride configuration (in bytes) for the even/odd split and re-interleave
// copies inside calculateRope: read every 2nd element, write densely, and
// vice versa on the way back.
const int data_segsize = sizeof(Tdata);
const int src_load_stride = 2 * sizeof(Tdata);
const int dst_load_stride = 1 * sizeof(Tdata);
const int src_write_stride = 1 * sizeof(Tdata);
const int dst_write_stride = 2 * sizeof(Tdata);
// Task distribution: the first `remaining_tasks` cores take one extra row.
const int batch_volume = seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
const int task_start_idx = (taskId < remaining_tasks ? taskId * base_tasks_per_core + taskId : taskId * base_tasks_per_core + remaining_tasks);
// NRAM buffer allocation with proper alignment
char *aligned_nram = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
// Setup position IDs if they fit in NRAM
Tindex *srcP = nullptr;
if (use_pos_ids_buffer) {
srcP = (Tindex *)aligned_nram;
__memcpy(srcP, pos_ids, seqlen * sizeof(Tindex), GDRAM2NRAM);
aligned_nram = (char *)(((size_t)srcP + seqlen * sizeof(Tindex) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
}
// Main processing buffers (pointers will be set per chunk)
Tdata *sin_cache = nullptr;
Tdata *cos_cache = nullptr;
Tdata *x1sin = nullptr;
Tdata *x0cos = nullptr;
Tdata *x0sin = nullptr;
Tdata *x1cos = nullptr;
Tdata *input_0 = nullptr;
Tdata *input_1 = nullptr;
Tdata *input_cache = nullptr;
// Main processing loop over this core's share of (token, head) rows.
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
// Calculate output and input indices
int seq_idx = i / nhead;
int head_idx = i % nhead;
// Output indices (y)
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
// Input indices (x)
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index (from the NRAM cache when available, else GDRAM).
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;
// Process the head dimension in chunks that fit in NRAM; `processed`
// counts table entries (element pairs), so data offsets advance by 2x.
int processed = 0;
while (processed < table_dim) {
// Calculate current chunk size
int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
int current_chunk_size = 2 * current_half_chunk;
int theta_offset = rot_offset + processed;
int dst_offset = out_offset + processed * 2;
int src_offset = in_offset + processed * 2;
// Carve the nine per-chunk buffers contiguously from the arena.
char *chunk_base = aligned_nram;
sin_cache = (Tdata *)chunk_base;
cos_cache = sin_cache + current_half_chunk;
x1sin = cos_cache + current_half_chunk;
x0cos = x1sin + current_half_chunk;
x0sin = x0cos + current_half_chunk;
x1cos = x0sin + current_half_chunk;
input_0 = x1cos + current_half_chunk;
input_1 = input_0 + current_half_chunk;
input_cache = input_1 + current_half_chunk;
calculateRope<Tdata>(
y, x, sin_table, cos_table,
sin_cache, cos_cache, x1sin, x0cos, x0sin, x1cos,
input_0, input_1, input_cache,
theta_offset, dst_offset, src_offset,
current_chunk_size, current_half_chunk,
data_segsize,
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
processed += current_half_chunk;
}
}
}
#include "rope_v2_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
namespace op::rope_v2::cpu {
// CPU descriptor owns no opaque state (create passes nullptr), so the
// defaulted destructor suffices.
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    // Validate the descriptors and collect shape/stride metadata first so we
    // fail before any allocation happens.
    auto result = RoPEv2Info::createRoPEv2Info(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(result);
    auto *cpu_handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    // The CPU backend needs no workspace and no opaque backend state.
    *desc_ptr = new Descriptor(
        result.take(),
        /*workspace_size=*/0,
        /*opaque=*/nullptr,
        cpu_handle->device,
        cpu_handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// CPU reference implementation of RoPE-v2 (NEOX-style half-split pairing):
// element i of the first half of each head is rotated with element
// i + table_dim of the second half, using sin/cos row pos_ids[tok].
// Heads are processed in parallel via OpenMP; fp16/bf16 math is done in
// float to limit rounding error.
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPEv2(const RoPEv2Info &info,
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table) {
#pragma omp parallel for
for (ptrdiff_t h = 0; h < ptrdiff_t(info.nhead); h++) {
for (size_t tok = 0; tok < info.seqlen; tok++) {
size_t x_offset = tok * info.x_stride_seqlen + h * info.x_stride_nhead;
size_t y_offset = tok * info.y_stride_seqlen + h * info.y_stride_nhead;
size_t pos_id = size_t(pos_ids[tok]);
size_t table_offset = pos_id * info.table_dim;
size_t half_dim = info.table_dim; // head_dim = 2 * half_dim (checked at create time)
for (size_t i = 0; i < info.table_dim; i++) {
// Pair element i (first half) with element i + half_dim (second half).
size_t pos0 = i;
size_t pos1 = i + half_dim;
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
// Half-precision types: upcast, rotate in float, downcast.
float x0 = utils::cast<float>(x[x_offset + pos0]),
x1 = utils::cast<float>(x[x_offset + pos1]),
sin__ = utils::cast<float>(sin_table[table_offset + i]),
cos__ = utils::cast<float>(cos_table[table_offset + i]);
y[y_offset + pos0] = utils::cast<Tdata>(x0 * cos__ - x1 * sin__);
y[y_offset + pos1] = utils::cast<Tdata>(x0 * sin__ + x1 * cos__);
} else {
Tdata x0 = x[x_offset + pos0],
x1 = x[x_offset + pos1],
sin__ = sin_table[table_offset + i],
cos__ = cos_table[table_offset + i];
y[y_offset + pos0] = x0 * cos__ - x1 * sin__;
y[y_offset + pos1] = x0 * sin__ + x1 * cos__;
}
}
}
}
return INFINI_STATUS_SUCCESS;
}
// Expands to a call of the CPU kernel with concrete data/index types.
#define CALCULATE_ROPE_V2(TDATA, TINDEX) \
calculateRoPEv2(_info, (TDATA *)y, (const TDATA *)x, (const TINDEX *)pos_ids, (const TDATA *)sin_table, (const TDATA *)cos_table)
// Dispatches on the position-id integer dtype; every case returns, so the
// expansion never falls through into surrounding code.
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE_V2(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE_V2(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE_V2(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE_V2(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE_V2(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE_V2(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE_V2(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE_V2(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Dispatches the CPU RoPE-v2 kernel: outer switch on the activation dtype,
// inner ROPE_TYPE switch on the position-id dtype. Every ROPE_TYPE
// expansion returns on all paths, so the cases below cannot fall through.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {
    switch (_info.data_type) {
    case INFINI_DTYPE_F16:
        ROPE_TYPE(fp16_t);
    case INFINI_DTYPE_BF16:
        ROPE_TYPE(bf16_t);
    case INFINI_DTYPE_F32:
        ROPE_TYPE(float);
    case INFINI_DTYPE_F64:
        ROPE_TYPE(double);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
#undef ROPE_TYPE
// BUG FIX: was `#undef CALCULATE_ROPE`, which undefines a name never defined
// here and leaked CALCULATE_ROPE_V2 past this point.
#undef CALCULATE_ROPE_V2
} // namespace op::rope_v2::cpu
#ifndef __INFINIOP_ROPE_V2_CPU_H__
#define __INFINIOP_ROPE_V2_CPU_H__
#include "../rope_v2.h"
DESCRIPTOR(cpu)
#endif // __INFINIOP_ROPE_V2_CPU_H__
#ifndef __INFINIOP_ROPE_V2_CUDA_KERNEL_CUH__
#define __INFINIOP_ROPE_V2_CUDA_KERNEL_CUH__
// Shared per-block body for RoPE kernels (CUDA/Metax/Iluvatar).
// Grid layout expected by callers: blockIdx.x = token, blockIdx.y = head;
// threads of the block stride over the table_dim rotation pairs, so any
// blockDim.x >= 1 is correct.
// Pairing is NEOX-style: element i of the first half of the head rotates
// with element i + table_dim of the second half.
template <typename Tdata, typename Tindex, typename Tangle>
__device__ void ropeThreadPerItemBlock(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
auto y_offset = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
auto x_offset = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;
// One sin/cos table row per token position.
size_t pos_id = size_t(pos_ids[blockIdx.x]);
auto table_offset = pos_id * table_dim;
const size_t half_dim = table_dim; // Head dimension = 2 * table_dim
// Block-stride loop: correct for any blockDim.x relative to table_dim.
for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
Tangle sin__ = sin_table[table_offset + i];
Tangle cos__ = cos_table[table_offset + i];
// Calculate positions in first and second halves
size_t pos0 = i;
size_t pos1 = i + half_dim;
if constexpr (std::is_same<Tdata, half>::value) {
// fp16 path: rotate in Tangle (float) precision, then convert back.
Tangle x0 = __half2float(x_[x_offset + pos0]);
Tangle x1 = __half2float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2half(y0);
y_[y_offset + pos1] = __float2half(y1);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
// bf16 path: same scheme with bfloat16 conversions.
Tangle x0 = __bfloat162float(x_[x_offset + pos0]);
Tangle x1 = __bfloat162float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2bfloat16(y0);
y_[y_offset + pos1] = __float2bfloat16(y1);
} else {
Tangle x0 = x_[x_offset + pos0];
Tangle x1 = x_[x_offset + pos1];
y_[y_offset + pos0] = x0 * cos__ - x1 * sin__;
y_[y_offset + pos1] = x0 * sin__ + x1 * cos__;
}
}
}
#endif
#ifndef __INFINIOP_ROPE_METAX_H__
#define __INFINIOP_ROPE_METAX_H__
#include "../rope.h"
DESCRIPTOR(metax)
#endif // __INFINIOP_ROPE_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "rope_metax.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
// Metax kernel entry point: thin wrapper that forwards to the shared
// per-block RoPE body in ../cuda/kernel.cuh. Launch with
// grid = (seqlen, nhead), any block size >= 1.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_METAX_KERNEL ropeThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
y_stride_seqlen, y_stride_nhead,
x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::metax {
// Backend-private state: keeps the Metax handle internals (e.g. device
// limits) alive for the descriptor's lifetime.
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Owns _opaque (allocated in create).
Descriptor::~Descriptor() {
delete _opaque;
}
// Validates the tensor descriptors and builds a Metax RoPE descriptor.
// No workspace is required; the Opaque member retains the handle internals
// so calculate() can query the device's max threads per block.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
    auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(info);
    // Create descriptor (`handle` is already a device::metax::Handle*, so the
    // previous second reinterpret_cast was redundant and has been removed).
    *desc_ptr = new Descriptor(
        info.take(),
        0,
        new Opaque{handle->internal()},
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Launches the Metax RoPE kernel: one block per (token, head) pair.
// block_size is the device's maxThreadsPerBlock (see CALCULATE_ROPE).
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
                             int block_size,
                             Tdata *y,
                             const Tdata *x,
                             const Tindex *pos_ids,
                             const Tdata *sin_table,
                             const Tdata *cos_table,
                             hcStream_t stream) {
    auto dimx = uint32_t(info.seqlen),
         dimy = uint32_t(info.nhead);
    // BUG FIX: was std::max, which requests table_dim threads whenever
    // table_dim > maxThreadsPerBlock and makes the launch invalid. The
    // kernel's block-stride loop covers table_dim with any block size, so
    // cap at the device limit instead.
    int nthreads = std::min(int(info.table_dim), block_size);
    ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
        y, x, pos_ids, sin_table, cos_table, info.table_dim,
        info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
    return INFINI_STATUS_SUCCESS;
}
// Expands to a kernel launch with concrete data/index types, forwarding the
// device's max block size from the opaque handle internals.
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, \
_opaque->internal->maxThreadsPerBlock(), \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(hcStream_t)stream)
// Dispatches on the position-id integer dtype; every case returns.
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Dispatches the Metax RoPE launch: outer switch on the activation dtype,
// inner ROPE_TYPE switch on the position-id dtype. Every ROPE_TYPE
// expansion returns on all paths, so the cases cannot fall through.
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(cuda_bfloat16);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
ROPE_TYPE(double);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Unreachable: every case above returns (kept to satisfy compilers that
// warn about a missing return).
return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::metax
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "rope_v2_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
namespace op::rope_v2::nvidia {
// NVIDIA kernel entry point: thin wrapper forwarding to the shared per-block
// RoPE body in ../cuda/kernel.cuh. Launch with grid = (seqlen, nhead),
// any block size >= 1.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropev2ThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
y_stride_seqlen, y_stride_nhead,
x_stride_seqlen, x_stride_nhead);
}
// Backend-private state: keeps the NVIDIA handle internals (e.g. device
// limits) alive for the descriptor's lifetime.
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
// Owns _opaque (allocated in create).
Descriptor::~Descriptor() {
delete _opaque;
}
// Validates the tensor descriptors and builds an NVIDIA RoPE-v2 descriptor.
// No workspace is required; the Opaque member retains the handle internals
// so calculate() can query the device's max threads per block.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
    auto info = RoPEv2Info::createRoPEv2Info(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(info);
    // Create descriptor (`handle` is already a device::nvidia::Handle*, so
    // the previous second reinterpret_cast was redundant and has been removed).
    *desc_ptr = new Descriptor(
        info.take(),
        0,
        new Opaque{handle->internal()},
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Launches the NVIDIA RoPE-v2 kernel: one block per (token, head) pair.
// block_size is the device's maxThreadsPerBlock (see CALCULATE_ROPE_V2).
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPEv2(const RoPEv2Info &info,
                               int block_size,
                               Tdata *y,
                               const Tdata *x,
                               const Tindex *pos_ids,
                               const Tdata *sin_table,
                               const Tdata *cos_table,
                               cudaStream_t stream) {
    auto dimx = uint32_t(info.seqlen),
         dimy = uint32_t(info.nhead);
    // BUG FIX: was std::max, which requests table_dim threads whenever
    // table_dim > maxThreadsPerBlock and makes the launch invalid
    // (cudaErrorInvalidConfiguration). The kernel's block-stride loop covers
    // table_dim with any block size, so cap at the device limit instead.
    int nthreads = std::min(int(info.table_dim), block_size);
    ropev2ThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
        y, x, pos_ids, sin_table, cos_table, info.table_dim,
        info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
    return INFINI_STATUS_SUCCESS;
}
// Expands to a kernel launch with concrete data/index types, forwarding the
// device's max block size from the opaque handle internals.
#define CALCULATE_ROPE_V2(TDATA, TINDEX) \
calculateRoPEv2(_info, \
_opaque->internal->maxThreadsPerBlock(), \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(cudaStream_t)stream)
// Dispatches on the position-id integer dtype; every case returns.
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE_V2(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE_V2(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE_V2(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE_V2(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE_V2(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE_V2(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE_V2(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE_V2(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Dispatches the NVIDIA RoPE-v2 launch: outer switch on the activation
// dtype, inner ROPE_TYPE switch on the position-id dtype. Every ROPE_TYPE
// expansion returns on all paths, so the cases cannot fall through.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {
    switch (_info.data_type) {
    case INFINI_DTYPE_F16:
        ROPE_TYPE(half);
    case INFINI_DTYPE_BF16:
        ROPE_TYPE(cuda_bfloat16);
    case INFINI_DTYPE_F32:
        ROPE_TYPE(float);
    case INFINI_DTYPE_F64:
        ROPE_TYPE(double);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
#undef ROPE_TYPE
// BUG FIX: was `#undef CALCULATE_ROPE`, which undefines a name never defined
// here and leaked CALCULATE_ROPE_V2 past this point.
#undef CALCULATE_ROPE_V2
} // namespace op::rope_v2::nvidia
#ifndef __INFINIOP_ROPE_V2_CUDA_H__
#define __INFINIOP_ROPE_V2_CUDA_H__
#include "../rope_v2.h"
DESCRIPTOR(nvidia)
#endif // __INFINIOP_ROPE_V2_CUDA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/rope_v2.h"
#ifdef ENABLE_CPU_API
#include "cpu/rope_v2_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/rope_v2_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_v2_ascend.h"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/rope_v2_bang.h"
#endif
#ifdef ENABLE_METAX_API
#include "metax/rope_v2_metax.h"
#endif
// C API: creates a RoPE-v2 descriptor for whichever backend the handle's
// device selects. Returns DEVICE_TYPE_NOT_SUPPORTED for devices compiled out
// or unknown.
__C infiniStatus_t infiniopCreateRoPEv2Descriptor(
    infiniopHandle_t handle,
    infiniopRoPEv2Descriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y,
    infiniopTensorDescriptor_t x,
    infiniopTensorDescriptor_t pos_ids,
    infiniopTensorDescriptor_t sin_table,
    infiniopTensorDescriptor_t cos_table) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::rope_v2::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::rope_v2::NAMESPACE::Descriptor **>(desc_ptr), \
y, \
x, \
pos_ids, \
sin_table, \
cos_table)
    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ASCEND_API
        CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_CAMBRICON_API
        CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MTHREADS_GPU
    // NOTE(review): this branch passes `t`, which is not declared anywhere in
    // this function (the legacy single-tensor RoPE API used `t` where v2 uses
    // y/x) — it cannot compile when ENABLE_MTHREADS_GPU is defined; confirm
    // the intended musa v2 signature before enabling.
    case DevMthreadsGpu: {
        return musaCreateRoPEDescriptor((MusaHandle_t)handle,
                                        (RoPEMusaDescriptor_t *)desc_ptr, t,
                                        pos_ids, sin_table, cos_table);
    }
#endif
    }
#undef CREATE
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// C API: reports the workspace bytes required by a RoPE-v2 descriptor
// (currently 0 for all backends created in this file).
__C infiniStatus_t infiniopGetRoPEv2WorkspaceSize(infiniopRoPEv2Descriptor_t desc,
                                                  size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::rope_v2::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_API
        GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        GET(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MTHREADS_GPU
    case DevMthreadsGpu: {
        return musaGetRoPEWorkspaceSize((RoPEMusaDescriptor_t)desc, size);
    }
#endif
    }
#undef GET
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// C API: runs RoPE-v2 on the backend selected at descriptor creation.
// `stream` is the backend's native stream handle (ignored by the CPU path).
__C infiniStatus_t infiniopRoPEv2(
    infiniopRoPEv2Descriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::rope_v2::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, x, pos_ids, sin_table, cos_table, stream)
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_API
        CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MTHREADS_GPU
    // NOTE(review): `t` is not declared in this function (legacy single-tensor
    // RoPE API leftover) — this branch cannot compile when ENABLE_MTHREADS_GPU
    // is defined; confirm the intended musa v2 call.
    case DevMthreadsGpu: {
        return musaRoPE((RoPEMusaDescriptor_t)desc, workspace, workspace_size,
                        t, pos_ids, sin_table, cos_table, stream);
    }
#endif
    }
#undef CALCULATE
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// C API: destroys a RoPE-v2 descriptor, invoking the backend destructor
// (which frees any opaque backend state).
__C infiniStatus_t
infiniopDestroyRoPEv2Descriptor(infiniopRoPEv2Descriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::rope_v2::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
        DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_API
        DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
        DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MTHREADS_GPU
    case DevMthreadsGpu: {
        return musaDestroyRoPEDescriptor((RoPEMusaDescriptor_t)desc);
    }
#endif
    }
#undef DELETE
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#ifndef __ROPE_V2_H__
#define __ROPE_V2_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
// Stamps out the per-backend Descriptor class inside op::rope_v2::NAMESPACE.
// Each backend .cc/.cu includes this header and provides the out-of-line
// definitions of ~Descriptor, create and calculate, plus its own Opaque.
// (Comments cannot go inside the macro body: line splicing would merge them
// with the following continuation line.)
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::rope_v2::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
RoPEv2Info _info; \
size_t _workspace_size; \
\
Descriptor( \
RoPEv2Info info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t pos_desc, \
infiniopTensorDescriptor_t sin_desc, \
infiniopTensorDescriptor_t cos_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *y, \
const void *x, \
const void *pos_ids, \
const void *sin_table, \
const void *cos_table, \
void *stream) const; \
}; \
}
// Validated metadata for a RoPE-v2 operator instance: dtypes, dimensions,
// and the (possibly non-contiguous) seqlen/nhead strides of x and y.
class RoPEv2Info {
private:
    RoPEv2Info() = default;

public:
    infiniDtype_t data_type, pos_type;
    size_t seqlen, nhead, dhead, table_len, table_dim;
    ptrdiff_t
        y_stride_seqlen,
        y_stride_nhead,
        x_stride_seqlen,
        x_stride_nhead;

    // Validates the five tensor descriptors and packs their metadata.
    // Expected shapes: y/x [seqlen, nhead, dhead], pos_ids [seqlen],
    // sin/cos tables [table_len, table_dim] with dhead == 2 * table_dim.
    // The innermost dimension of x/y must be contiguous; the tables must be
    // fully contiguous.
    static utils::Result<RoPEv2Info> createRoPEv2Info(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc,
        infiniopTensorDescriptor_t pos_desc,
        infiniopTensorDescriptor_t sin_desc,
        infiniopTensorDescriptor_t cos_desc) {
        // BUG FIX: x_desc was missing from this null check even though it is
        // dereferenced (x_desc->dtype()) just below.
        CHECK_OR_RETURN(
            y_desc != nullptr && x_desc != nullptr && pos_desc != nullptr
                && sin_desc != nullptr && cos_desc != nullptr,
            INFINI_STATUS_NULL_POINTER);
        const infiniDtype_t data_type = y_desc->dtype();
        const infiniDtype_t pos_type = pos_desc->dtype();
        CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(),
                        INFINI_STATUS_BAD_TENSOR_DTYPE);
        CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
        CHECK_DTYPE_ANY_INT(pos_type);
        CHECK_OR_RETURN(y_desc->ndim() == 3
                            && x_desc->ndim() == 3
                            && pos_desc->ndim() == 1
                            && sin_desc->ndim() == 2
                            && cos_desc->ndim() == 2,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        const auto seqlen = y_desc->dim(0),
                   nhead = y_desc->dim(1),
                   dhead = y_desc->dim(2),
                   table_len = sin_desc->dim(0),
                   table_dim = sin_desc->dim(1);
        CHECK_OR_RETURN(seqlen == x_desc->dim(0)
                            && seqlen == pos_desc->dim(0)
                            && nhead == x_desc->dim(1) && dhead == x_desc->dim(2)
                            && table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        CHECK_OR_RETURN(dhead == table_dim * 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
        // Last dimension of x and y must be contiguous
        CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        // sin table and cos table must be totally contiguous
        CHECK_OR_RETURN(sin_desc->isContiguous() && cos_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES);
        return utils::Result<RoPEv2Info>(RoPEv2Info{
            data_type,
            pos_type,
            seqlen,
            nhead,
            dhead,
            table_len,
            table_dim,
            y_desc->stride(0),
            y_desc->stride(1),
            x_desc->stride(0),
            x_desc->stride(1),
        });
    }
};
#endif
......@@ -2,27 +2,48 @@ from ast import List
import numpy as np
import gguf
from typing import List
from enum import Enum
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides
class Algorithm(Enum):
    # RoPE pairing schemes.
    GPT_J = 0  # rotate interleaved pairs (t[..., 0::2] with t[..., 1::2])
    GPT_NEOX = 1  # rotate split halves (t[..., :dh//2] with t[..., dh//2:])


def rotary_embedding(t, sin, cos, algo):
    """Reference rotary positional embedding.

    Args:
        t: array of shape [seq_len, n_head, dh] with dh even.
        sin, cos: per-position tables of shape [seq_len, dh // 2].
        algo: integer Algorithm value selecting the pairing scheme.

    Returns:
        A new array of the same shape as ``t`` with RoPE applied.

    Note: the previous body contained interleaved diff residue (duplicated
    expand_dims lines and orphan statements referencing undefined names);
    this is the cleaned reconstruction.
    """

    def _rope(sin, cos, t1, t2):
        # Broadcast the tables over the head axis.
        cos = np.expand_dims(cos, axis=1)  # [seq_len, 1, dh // 2]
        sin = np.expand_dims(sin, axis=1)  # [seq_len, 1, dh // 2]
        t_out_1 = t1 * cos - t2 * sin
        t_out_2 = t1 * sin + t2 * cos
        return t_out_1, t_out_2

    dh = t.shape[-1]
    assert dh % 2 == 0, "Embedding dimension must be even."
    t_out = np.empty_like(t)
    if algo == Algorithm.GPT_J.value:
        t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
        t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
        t_out_even, t_out_odd = _rope(sin, cos, t_even, t_odd)
        t_out[..., 0::2] = t_out_even
        t_out[..., 1::2] = t_out_odd
    else:
        half_dim = dh // 2
        t_first = t[..., :half_dim]
        t_second = t[..., half_dim:]
        t_out_first, t_out_second = _rope(sin, cos, t_first, t_second)
        t_out[..., :half_dim] = t_out_first
        t_out[..., half_dim:] = t_out_second
    return t_out
......@@ -52,6 +73,7 @@ class RoPETestCase(InfiniopTestCase):
pos_ids: np.ndarray,
sin_table: np.ndarray,
cos_table: np.ndarray,
algo: int,
):
super().__init__("rope")
self.y = y
......@@ -63,10 +85,12 @@ class RoPETestCase(InfiniopTestCase):
self.pos_ids = pos_ids
self.sin_table = sin_table
self.cos_table = cos_table
self.algo = algo
def write_test(self, test_writer: "InfiniopTestWriter"):
super().write_test(test_writer)
test_writer.add_int32(test_writer.gguf_key("algo"), self.algo)
test_writer.add_tensor(
test_writer.gguf_key("y"), self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype)
)
......@@ -97,6 +121,7 @@ class RoPETestCase(InfiniopTestCase):
self.x.astype(np.float64),
self.sin_table.astype(np.float64),
self.cos_table.astype(np.float64),
self.algo,
)
test_writer.add_tensor(
test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64
......@@ -121,27 +146,35 @@ if __name__ == "__main__":
((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
]
_ALGO = [
Algorithm.GPT_J,
Algorithm.GPT_NEOX,
]
_TENSOR_DTYPES_ = [np.float16, np.float32]
test_writer = InfiniopTestWriter("rope.gguf")
test_cases = []
# Generate one test case per (algorithm, dtype, shape/stride) combination.
# The legacy single-algorithm loop was removed: it called RoPETestCase
# without the now-required `algo` argument and duplicated every case.
for algo in _ALGO:
    for dtype in _TENSOR_DTYPES_:
        for shape, stride_x, stride_y in _TEST_CASES_:
            x = np.random.rand(*shape).astype(dtype)
            # y is produced by the operator; an empty placeholder suffices.
            y = np.empty(tuple(0 for _ in shape), dtype=dtype)
            pos_ids = np.arange(0, x.shape[0], dtype=np.int32)
            sin_table, cos_table = sin_cos_table(pos_ids, x.shape[2], theta=1e5, dtype=dtype)
            test_case = RoPETestCase(
                y=y,
                x=x,
                shape_y=shape,
                shape_x=shape,
                stride_y=stride_y,
                stride_x=stride_x,
                pos_ids=pos_ids,
                sin_table=sin_table,
                cos_table=cos_table,
                algo=algo.value,
            )
            test_cases.append(test_case)
test_writer.add_tests(test_cases)
test_writer.save()
......@@ -361,6 +361,8 @@ def rope_(lib):
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_int32,
]
lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
......@@ -379,6 +381,7 @@ def rope_(lib):
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyRoPEDescriptor.restype = c_int32
......@@ -387,42 +390,6 @@ def rope_(lib):
]
# ctypes signature registration for the standalone RoPE-v2 C API.
# NOTE(review): this commit merges rope_v2 into rope; confirm whether this
# registration is still required alongside the extended `rope_` block above.
@OpRegister.operator
def rope_v2_(lib):
# Descriptor creation: handle, out-descriptor, y, x, pos_ids, sin, cos.
lib.infiniopCreateRoPEv2Descriptor.restype = c_int32
lib.infiniopCreateRoPEv2Descriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetRoPEv2WorkspaceSize.restype = c_int32
lib.infiniopGetRoPEv2WorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
# Execution: desc, workspace, workspace_size, y, x, pos_ids, sin, cos, stream.
lib.infiniopRoPEv2.restype = c_int32
lib.infiniopRoPEv2.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyRoPEv2Descriptor.restype = c_int32
lib.infiniopDestroyRoPEv2Descriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def sub_(lib):
lib.infiniopCreateSubDescriptor.restype = c_int32
......
......@@ -51,15 +51,27 @@ class Inplace(Enum):
INPLACE_X = auto()
# RoPE pairing schemes exercised by the operator test.
class Algorithm(Enum):
GPT_J = 0
GPT_NEOX = 1
# In-place variants: write into a fresh y, or rotate x in place.
_INPLACE = [
Inplace.OUT_OF_PLACE,
Inplace.INPLACE_X,
]
_ALGO = [
Algorithm.GPT_J,
Algorithm.GPT_NEOX,
]
# Cross product: every base case is run with every inplace and algo variant.
_TEST_CASES = [
test_case + (inplace_item, algo_item)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
for algo_item in _ALGO
]
DEBUG = False
......@@ -68,27 +80,45 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def rotary_embedding(ans, t, sin, cos, device, algo):
    """Reference rotary positional embedding; writes the result into `ans` in place.

    Args:
        ans: output tensor, same shape as `t`.
        t: input tensor; last dimension is the embedding dim `dh` (must be even).
        sin, cos: angle tables of shape [seq_len, dh // 2].
        device: InfiniDeviceEnum value; on CPU the math is upcast to float32.
        algo: Algorithm.GPT_J rotates interleaved (even, odd) element pairs;
            otherwise (GPT_NEOX) rotates (first-half, second-half) pairs.
    """

    def _torch_rope(sin, cos, t1, t2):
        # Rotate each (t1, t2) pair by the per-position angle.
        cos = cos.unsqueeze(1)  # [seq_len, 1, dh // 2]
        sin = sin.unsqueeze(1)  # [seq_len, 1, dh // 2]
        if device == InfiniDeviceEnum.CPU:
            # CPU reference path computes in float32.
            (t1, t2, cos, sin) = (
                t1.float(),
                t2.float(),
                cos.float(),
                sin.float(),
            )
        t_out_1 = t1 * cos - t2 * sin
        t_out_2 = t1 * sin + t2 * cos
        return t_out_1, t_out_2

    dh = t.shape[-1]
    dt = t.dtype
    assert dh % 2 == 0, "Embedding dimension must be even."

    if algo == Algorithm.GPT_J:
        # GPT-J layout: pairs are interleaved even/odd positions.
        t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
        t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
        t_out_even, t_out_odd = _torch_rope(sin, cos, t_even, t_odd)
        ans[..., 0::2] = t_out_even.to(dt)
        ans[..., 1::2] = t_out_odd.to(dt)
    else:
        # GPT-NeoX layout: pairs are (first half, second half) of the last dim.
        # NOTE: the raw diff text also carried the stale GPT-J assignments here
        # (undefined t_out_even/t_out_odd in this branch); they are removed.
        half_dim = dh // 2
        t_first = t[..., :half_dim]
        t_second = t[..., half_dim:]
        t_out_first, t_out_second = _torch_rope(sin, cos, t_first, t_second)
        ans[..., :half_dim] = t_out_first.to(dt)
        ans[..., half_dim:] = t_out_second.to(dt)
def sin_cos_table(pos, dim, device, theta, dtype):
......@@ -108,6 +138,7 @@ def test(
x_strides=None,
y_strides=None,
inplace=Inplace.OUT_OF_PLACE,
algo=Algorithm.GPT_J,
dtype=torch.float32,
sync=None,
):
......@@ -120,7 +151,7 @@ def test(
y = TestTensor(shape, y_strides, dtype, device)
print(
f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace} algo:{algo}"
)
theta = 1e5
pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
......@@ -134,6 +165,7 @@ def test(
sin_table.torch_tensor(),
cos_table.torch_tensor(),
device,
algo,
)
descriptor = infiniopOperatorDescriptor_t()
......@@ -150,6 +182,7 @@ def test(
pos.descriptor,
sin_table.descriptor,
cos_table.descriptor,
algo.value,
)
)
......
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceEnum,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # (shape, x_strides, y_strides)
    ((1, 32, 128), None, None),
    ((10, 32, 64), None, None),
    # Ascend does not support this case yet: a last dimension <= 32 is
    # problematic, possibly related to the internal implementation of its core
    # GatherMask interface; 48, 64 and 128 are all currently supported.
    ((4, 1, 32), (64, 64, 1), None),
    ((11, 33, 128), None, (8000, 200, 1)),
    ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
    InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-3},
}


class Inplace(Enum):
    # OUT_OF_PLACE: y is a separate output tensor; INPLACE_X: y aliases x
    # (see the alias handling in `test`).
    OUT_OF_PLACE = auto()
    INPLACE_X = auto()


_INPLACE = [
    Inplace.OUT_OF_PLACE,
    Inplace.INPLACE_X,
]
# Cross product: every base case is run once per in-place variant.
_TEST_CASES = [
    test_case + (inplace_item,)
    for test_case in _TEST_CASES_
    for inplace_item in _INPLACE
]
# Runtime switches; overwritten from CLI arguments in __main__.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def rotary_embedding(ans, t, sin, cos, device):
    """Reference RoPE (half-split pairing); writes the result into `ans` in place.

    Args:
        ans: output tensor, same shape as `t`.
        t: input tensor; last dimension is the embedding dim (must be even).
        sin, cos: angle tables of shape [seq_len, dim // 2].
        device: InfiniDeviceEnum value; on CPU the math is upcast to float32.
    """
    dim = t.shape[-1]
    dt = t.dtype
    assert dim % 2 == 0, "Embedding dimension must be even."
    half = dim // 2
    lo = t[..., :half]
    hi = t[..., half:]
    # Broadcast the tables across the head dimension: [seq_len, 1, half].
    cos_b = cos.unsqueeze(1)
    sin_b = sin.unsqueeze(1)
    if device == InfiniDeviceEnum.CPU:
        # CPU reference path computes in float32.
        lo, hi = lo.float(), hi.float()
        cos_b, sin_b = cos_b.float(), sin_b.float()
    ans[..., :half] = (lo * cos_b - hi * sin_b).to(dt)
    ans[..., half:] = (lo * sin_b + hi * cos_b).to(dt)
def sin_cos_table(pos, dim, device, theta, dtype):
    """Build the RoPE sin/cos lookup tables.

    Args:
        pos: 1-D tensor of position indices.
        dim: embedding dimension (must be even).
        device: target device for the resulting TestTensors.
        theta: RoPE frequency base.
        dtype: dtype of the resulting tables.

    Returns:
        (sin_table, cos_table) TestTensors of shape [len(pos), dim // 2].
    """
    assert dim % 2 == 0, "Embedding dimension must be even."
    half_indices = torch.arange(0, dim, 2)[: (dim // 2)].float()
    freqs = 1.0 / (theta ** (half_indices / dim))
    # Outer product: one row of angles per position.
    angles = torch.outer(pos.cpu(), freqs)
    sin_table = TestTensor.from_torch(torch.sin(angles), dtype, device)
    cos_table = TestTensor.from_torch(torch.cos(angles), dtype, device)
    return sin_table, cos_table
def test(
    handle,
    device,
    shape,
    x_strides=None,
    y_strides=None,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=torch.float32,
    sync=None,
):
    # End-to-end check of the library RoPE-v2 kernel against the PyTorch
    # reference `rotary_embedding`. `handle` is the infiniop library handle;
    # `sync` is an optional device-synchronization callback.
    x = TestTensor(shape, x_strides, dtype, device)
    if inplace == Inplace.INPLACE_X:
        # In-place only makes sense when x and y share a layout; skip otherwise.
        if x_strides != y_strides:
            return
        y = x
    else:
        y = TestTensor(shape, y_strides, dtype, device)
    print(
        f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
    )
    theta = 1e5
    # Positions 0..seq_len-1; the tables are sized from the last dim (shape[2]).
    pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
    sin_table, cos_table = sin_cos_table(
        pos.torch_tensor(), x.shape[2], x.device, theta, dtype
    )
    # Compute the PyTorch reference answer into y's torch-side tensor.
    rotary_embedding(
        y.torch_tensor(),
        x.torch_tensor(),
        sin_table.torch_tensor(),
        cos_table.torch_tensor(),
        device,
    )
    descriptor = infiniopOperatorDescriptor_t()
    if sync is not None:
        sync()
    check_error(
        LIBINFINIOP.infiniopCreateRoPEv2Descriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            pos.descriptor,
            sin_table.descriptor,
            cos_table.descriptor,
        )
    )
    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [y, x, pos, sin_table, cos_table]:
        tensor.destroy_desc()
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRoPEv2WorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, x.device)

    def lib_rope_v2():
        # Invoke the library kernel. NOTE(review): the trailing None is
        # presumably the stream/queue argument — confirm against the C API.
        check_error(
            LIBINFINIOP.infiniopRoPEv2(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                pos.data(),
                sin_table.data(),
                cos_table.data(),
                None,
            )
        )

    lib_rope_v2()
    if sync is not None:
        sync()
    # Compare the library output with the reference within per-dtype tolerances.
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    # Optional timing of both the PyTorch reference and the library kernel.
    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(
                y.torch_tensor(),
                x.torch_tensor(),
                sin_table.torch_tensor(),
                cos_table.torch_tensor(),
                device,
            ),
            device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            " lib", lambda: lib_rope_v2(), device, NUM_PRERUN, NUM_ITERATIONS
        )
    check_error(LIBINFINIOP.infiniopDestroyRoPEv2Descriptor(descriptor))
if __name__ == "__main__":
    args = get_args()
    # Pull the runtime options from the CLI into the module-level switches
    # that `test` reads.
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations
    # Run the full RoPE-v2 case matrix on every requested device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
    print("\033[92mTest passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment