Unverified Commit f9d16628 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #429 from InfiniTensor/issue/428_merge_rope_and_rope_v2

Issue/428: Merge `rope_v2` into `rope`
parents 15ac0191 9f0ae734
#include "../../../devices/bang/common_bang.h"
__nram__ char nram_buffer[NRAM_MAX_SIZE];
// Apply rotary position embedding to one NRAM-sized chunk of a head on the
// MLU. This path pairs *interleaved* (even, odd) elements within the chunk.
// NOTE(review): the CPU/CUDA paths in this change pair element i with
// element i + table_dim (split halves) — confirm both layouts are intended.
//
// All scratch pointers (sin_cache .. input_cache) are pre-partitioned NRAM
// buffers: each holds half_chunk_size elements except input_cache, which
// holds chunk_size. Stride arguments are in BYTES, as required by the
// strided form of __memcpy; data_segsize is the per-segment copy width.
template <typename Tdata>
__mlu_device__ void calculateRope(
Tdata *out, const Tdata *in,
const Tdata *sin_table, const Tdata *cos_table,
Tdata *sin_cache, Tdata *cos_cache,
Tdata *x1sin, Tdata *x0cos, Tdata *x0sin, Tdata *x1cos,
Tdata *input_0, Tdata *input_1, Tdata *input_cache,
int theta_index, int out_index, int in_index,
int chunk_size, int half_chunk_size, int data_segsize,
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
// Load sin/cos data
__memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
__memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Load input data
__memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Split input into even and odd positions
// (strided NRAM->NRAM gather: half_chunk_size segments of data_segsize bytes)
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
// Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos
__bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
__bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
__bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
__bang_mul(x1cos, input_1, cos_cache, half_chunk_size);
// Results are written back over input_0/input_1 once both products are done.
__bang_sub(input_0, x0cos, x1sin, half_chunk_size);
__bang_add(input_1, x0sin, x1cos, half_chunk_size);
// Interleave results back into output buffer
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
// Write back results
__memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
}
// RoPE kernel for Cambricon BANG. The seqlen * nhead (token, head) pairs are
// split across tasks (taskId/taskDim are BANG builtins); each pair's head of
// dhead = 2 * table_dim elements is rotated in NRAM-sized chunks via
// calculateRope above.
template <typename Tdata, typename Tindex>
__mlu_global__ void ropeKernel(
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
uint32_t seqlen,
uint32_t nhead,
uint32_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
// Calculate available NRAM space after alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));
// Key variables that determine execution path
// NOTE(review): when pos_ids are cached below they may occupy up to half of
// nram_usable, but max_chunk_elements is still computed from the full
// budget — the per-chunk buffers allocated from aligned_nram could then
// overrun NRAM for long sequences. Verify the combined sizing.
const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
// Common stride configurations (in bytes, for strided __memcpy):
// load de-interleaves stride-2 input into two dense halves; write re-interleaves.
const int data_segsize = sizeof(Tdata);
const int src_load_stride = 2 * sizeof(Tdata);
const int dst_load_stride = 1 * sizeof(Tdata);
const int src_write_stride = 1 * sizeof(Tdata);
const int dst_write_stride = 2 * sizeof(Tdata);
// Task distribution: first `remaining_tasks` cores take one extra item.
const int batch_volume = seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
const int task_start_idx = (taskId < remaining_tasks ? taskId * base_tasks_per_core + taskId : taskId * base_tasks_per_core + remaining_tasks);
// NRAM buffer allocation with proper alignment
char *aligned_nram = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
// Setup position IDs if they fit in NRAM (avoids a GDRAM read per item)
Tindex *srcP = nullptr;
if (use_pos_ids_buffer) {
srcP = (Tindex *)aligned_nram;
__memcpy(srcP, pos_ids, seqlen * sizeof(Tindex), GDRAM2NRAM);
aligned_nram = (char *)(((size_t)srcP + seqlen * sizeof(Tindex) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
}
// Main processing buffers (pointers will be set per chunk)
Tdata *sin_cache = nullptr;
Tdata *cos_cache = nullptr;
Tdata *x1sin = nullptr;
Tdata *x0cos = nullptr;
Tdata *x0sin = nullptr;
Tdata *x1cos = nullptr;
Tdata *input_0 = nullptr;
Tdata *input_1 = nullptr;
Tdata *input_cache = nullptr;
// Main processing loop over this core's (token, head) items
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
// Calculate output and input indices
int seq_idx = i / nhead;
int head_idx = i % nhead;
// Output indices (y)
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
// Input indices (x)
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;
// Process in chunks that fit in NRAM
int processed = 0;
while (processed < table_dim) {
// Calculate current chunk size (last chunk may be short)
int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
int current_chunk_size = 2 * current_half_chunk;
int theta_offset = rot_offset + processed;
// Data offsets advance 2 elements per table entry (interleaved pairs).
int dst_offset = out_offset + processed * 2;
int src_offset = in_offset + processed * 2;
// Set up NRAM buffers for this chunk: 8 half-chunk scratch buffers
// followed by the full-chunk staging buffer, carved sequentially.
char *chunk_base = aligned_nram;
sin_cache = (Tdata *)chunk_base;
cos_cache = sin_cache + current_half_chunk;
x1sin = cos_cache + current_half_chunk;
x0cos = x1sin + current_half_chunk;
x0sin = x0cos + current_half_chunk;
x1cos = x0sin + current_half_chunk;
input_0 = x1cos + current_half_chunk;
input_1 = input_0 + current_half_chunk;
input_cache = input_1 + current_half_chunk;
calculateRope<Tdata>(
y, x, sin_table, cos_table,
sin_cache, cos_cache, x1sin, x0cos, x0sin, x1cos,
input_0, input_1, input_cache,
theta_offset, dst_offset, src_offset,
current_chunk_size, current_half_chunk,
data_segsize,
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
processed += current_half_chunk;
}
}
}
#include "rope_v2_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
namespace op::rope_v2::cpu {
Descriptor::~Descriptor() = default;
// Build a CPU RoPE-v2 descriptor: validate the five tensor descriptors,
// capture shapes/strides into a RoPEv2Info, and wrap it. The CPU backend
// needs no workspace and no opaque backend state.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    // Validate everything up front; bail out with the validation status.
    auto result = RoPEv2Info::createRoPEv2Info(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(result);
    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    // workspace_size = 0, opaque = nullptr: nothing backend-specific to keep.
    *desc_ptr = new Descriptor(result.take(), 0, nullptr, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// CPU reference implementation of RoPE-v2 (split-halves pairing): element i
// of a head is rotated together with element i + table_dim, using the
// sin/cos rows selected by pos_ids[token]. Heads are processed in parallel
// via OpenMP; half-precision inputs are computed in float.
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPEv2(const RoPEv2Info &info,
                               Tdata *y,
                               const Tdata *x,
                               const Tindex *pos_ids,
                               const Tdata *sin_table,
                               const Tdata *cos_table) {
#pragma omp parallel for
    for (ptrdiff_t h = 0; h < ptrdiff_t(info.nhead); h++) {
        for (size_t tok = 0; tok < info.seqlen; tok++) {
            const size_t src = tok * info.x_stride_seqlen + h * info.x_stride_nhead;
            const size_t dst = tok * info.y_stride_seqlen + h * info.y_stride_nhead;
            const size_t tbl = size_t(pos_ids[tok]) * info.table_dim;
            const size_t half = info.table_dim; // head_dim = 2 * half
            for (size_t i = 0; i < half; i++) {
                // Pair (lo, hi) straddles the two halves of the head.
                const size_t lo = i;
                const size_t hi = i + half;
                if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
                    // Promote to float for the rotation, round back on store.
                    float a = utils::cast<float>(x[src + lo]);
                    float b = utils::cast<float>(x[src + hi]);
                    float s = utils::cast<float>(sin_table[tbl + i]);
                    float c = utils::cast<float>(cos_table[tbl + i]);
                    y[dst + lo] = utils::cast<Tdata>(a * c - b * s);
                    y[dst + hi] = utils::cast<Tdata>(a * s + b * c);
                } else {
                    Tdata a = x[src + lo];
                    Tdata b = x[src + hi];
                    Tdata s = sin_table[tbl + i];
                    Tdata c = cos_table[tbl + i];
                    y[dst + lo] = a * c - b * s;
                    y[dst + hi] = a * s + b * c;
                }
            }
        }
    }
    return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE_V2(TDATA, TINDEX) \
calculateRoPEv2(_info, (TDATA *)y, (const TDATA *)x, (const TINDEX *)pos_ids, (const TDATA *)sin_table, (const TDATA *)cos_table)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE_V2(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE_V2(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE_V2(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE_V2(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE_V2(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE_V2(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE_V2(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE_V2(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Dispatch the CPU RoPE-v2 computation: outer switch selects the data dtype,
// ROPE_TYPE (defined above) then selects the position-id integer dtype.
// Every macro case returns, so no break statements are needed.
// workspace/workspace_size/stream are unused on the CPU path.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {
    switch (_info.data_type) {
    case INFINI_DTYPE_F16:
        ROPE_TYPE(fp16_t);
    case INFINI_DTYPE_BF16:
        ROPE_TYPE(bf16_t);
    case INFINI_DTYPE_F32:
        ROPE_TYPE(float);
    case INFINI_DTYPE_F64:
        ROPE_TYPE(double);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
#undef ROPE_TYPE
// Fix: the helper macro defined above is CALCULATE_ROPE_V2, not
// CALCULATE_ROPE — undef the correct name so it does not leak further.
#undef CALCULATE_ROPE_V2
} // namespace op::rope_v2::cpu
#ifndef __INFINIOP_ROPE_V2_CPU_H__
#define __INFINIOP_ROPE_V2_CPU_H__
#include "../rope_v2.h"
DESCRIPTOR(cpu)
#endif // __INFINIOP_ROPE_V2_CPU_H__
#ifndef __INFINIOP_ROPE_V2_CUDA_KERNEL_CUH__
#define __INFINIOP_ROPE_V2_CUDA_KERNEL_CUH__
// Rotate one (token, head) pair. Grid mapping: blockIdx.x = token,
// blockIdx.y = head; the 1-D block's threads stride over the table_dim
// rotary pairs, so any blockDim.x works. Split-halves pairing: element i is
// rotated with element i + table_dim. Half/bf16 inputs are promoted to
// Tangle (float) for the rotation and converted back on store.
template <typename Tdata, typename Tindex, typename Tangle>
__device__ void ropeThreadPerItemBlock(
    Tdata *y_,
    const Tdata *x_,
    const Tindex *__restrict__ pos_ids,
    const Tangle *__restrict__ sin_table,
    const Tangle *__restrict__ cos_table,
    size_t table_dim,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead) {
    const auto dst_base = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
    const auto src_base = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;
    // Row of the sin/cos tables for this token's position id.
    const auto tbl_base = size_t(pos_ids[blockIdx.x]) * table_dim;
    for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
        const size_t lo = i;
        const size_t hi = i + table_dim; // head dimension = 2 * table_dim
        const Tangle s = sin_table[tbl_base + i];
        const Tangle c = cos_table[tbl_base + i];
        if constexpr (std::is_same<Tdata, half>::value) {
            const Tangle a = __half2float(x_[src_base + lo]);
            const Tangle b = __half2float(x_[src_base + hi]);
            y_[dst_base + lo] = __float2half(a * c - b * s);
            y_[dst_base + hi] = __float2half(a * s + b * c);
        } else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
            const Tangle a = __bfloat162float(x_[src_base + lo]);
            const Tangle b = __bfloat162float(x_[src_base + hi]);
            y_[dst_base + lo] = __float2bfloat16(a * c - b * s);
            y_[dst_base + hi] = __float2bfloat16(a * s + b * c);
        } else {
            const Tangle a = x_[src_base + lo];
            const Tangle b = x_[src_base + hi];
            y_[dst_base + lo] = a * c - b * s;
            y_[dst_base + hi] = a * s + b * c;
        }
    }
}
#endif
#ifndef __INFINIOP_ROPE_METAX_H__
#define __INFINIOP_ROPE_METAX_H__
#include "../rope.h"
DESCRIPTOR(metax)
#endif // __INFINIOP_ROPE_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "rope_metax.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
// Metax kernel entry point: a thin wrapper that forwards all arguments to
// the shared per-(token, head) routine in ../cuda/kernel.cuh.
// Expected launch: grid = (seqlen, nhead), 1-D block of any size.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_METAX_KERNEL ropeThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
y_stride_seqlen, y_stride_nhead,
x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::metax {
// Backend-private state: keeps the metax handle internals (device limits
// such as maxThreadsPerBlock) alive for the descriptor's lifetime.
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// The descriptor owns its Opaque; release it on destruction.
Descriptor::~Descriptor() {
delete _opaque;
}
// Build a metax RoPE descriptor: validate the tensor descriptors into a
// RoPEInfo and stash the handle internals (needed at launch time for the
// device's block-size limit). No workspace is required.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
    auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(info);
    // Create descriptor. `handle` is already a device::metax::Handle*, so the
    // previous second reinterpret_cast was redundant and has been dropped.
    *desc_ptr = new Descriptor(
        info.take(),
        0,
        new Opaque{handle->internal()},
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Launch the metax RoPE kernel: one block per (token, head) pair.
// block_size is the device's maxThreadsPerBlock.
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
                             int block_size,
                             Tdata *y,
                             const Tdata *x,
                             const Tindex *pos_ids,
                             const Tdata *sin_table,
                             const Tdata *cos_table,
                             hcStream_t stream) {
    auto dimx = uint32_t(info.seqlen),
         dimy = uint32_t(info.nhead);
    // Fix: cap the thread count at the device limit instead of taking the
    // max. The kernel's strided inner loop already handles
    // table_dim > blockDim.x, whereas requesting more threads than
    // maxThreadsPerBlock would make the launch fail outright.
    int nthreads = std::min(int(info.table_dim), block_size);
    ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
        y, x, pos_ids, sin_table, cos_table, info.table_dim,
        info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
    return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, \
_opaque->internal->maxThreadsPerBlock(), \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(hcStream_t)stream)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Dispatch the metax RoPE launch: outer switch on the data dtype, ROPE_TYPE
// (defined above) on the position-id dtype. Every macro case returns.
// workspace/workspace_size are unused by this backend.
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(cuda_bfloat16);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
ROPE_TYPE(double);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Unreachable: every switch path above returns. Kept as a defensive
// fallthrough for compilers that warn about a missing return.
return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::metax
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "rope_v2_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
namespace op::rope_v2::nvidia {
// NVIDIA kernel entry point: a thin wrapper that forwards all arguments to
// the shared per-(token, head) routine in ../cuda/kernel.cuh.
// Expected launch: grid = (seqlen, nhead), 1-D block of any size.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropev2ThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
y_stride_seqlen, y_stride_nhead,
x_stride_seqlen, x_stride_nhead);
}
// Backend-private state: keeps the nvidia handle internals (device limits
// such as maxThreadsPerBlock) alive for the descriptor's lifetime.
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
// The descriptor owns its Opaque; release it on destruction.
Descriptor::~Descriptor() {
delete _opaque;
}
// Build an nvidia RoPE-v2 descriptor: validate the tensor descriptors into
// a RoPEv2Info and stash the handle internals (needed at launch time for
// the device's block-size limit). No workspace is required.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
    auto info = RoPEv2Info::createRoPEv2Info(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(info);
    // Create descriptor. `handle` is already a device::nvidia::Handle*, so
    // the previous second reinterpret_cast was redundant and has been dropped.
    *desc_ptr = new Descriptor(
        info.take(),
        0,
        new Opaque{handle->internal()},
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Launch the nvidia RoPE-v2 kernel: one block per (token, head) pair.
// block_size is the device's maxThreadsPerBlock.
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPEv2(const RoPEv2Info &info,
                               int block_size,
                               Tdata *y,
                               const Tdata *x,
                               const Tindex *pos_ids,
                               const Tdata *sin_table,
                               const Tdata *cos_table,
                               cudaStream_t stream) {
    auto dimx = uint32_t(info.seqlen),
         dimy = uint32_t(info.nhead);
    // Fix: cap the thread count at the device limit instead of taking the
    // max. The kernel's strided inner loop already handles
    // table_dim > blockDim.x, whereas requesting more threads than
    // maxThreadsPerBlock would make the launch fail outright.
    int nthreads = std::min(int(info.table_dim), block_size);
    ropev2ThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
        y, x, pos_ids, sin_table, cos_table, info.table_dim,
        info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
    return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE_V2(TDATA, TINDEX) \
calculateRoPEv2(_info, \
_opaque->internal->maxThreadsPerBlock(), \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(cudaStream_t)stream)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE_V2(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE_V2(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE_V2(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE_V2(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE_V2(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE_V2(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE_V2(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE_V2(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Dispatch the nvidia RoPE-v2 launch: outer switch on the data dtype,
// ROPE_TYPE (defined above) on the position-id dtype. Every macro case
// returns. workspace/workspace_size are unused by this backend.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {
    switch (_info.data_type) {
    case INFINI_DTYPE_F16:
        ROPE_TYPE(half);
    case INFINI_DTYPE_BF16:
        ROPE_TYPE(cuda_bfloat16);
    case INFINI_DTYPE_F32:
        ROPE_TYPE(float);
    case INFINI_DTYPE_F64:
        ROPE_TYPE(double);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Unreachable: every switch path above returns. Kept as a defensive
    // fallthrough for compilers that warn about a missing return.
    return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
// Fix: the helper macro defined above is CALCULATE_ROPE_V2, not
// CALCULATE_ROPE — undef the correct name so it does not leak further.
#undef CALCULATE_ROPE_V2
} // namespace op::rope_v2::nvidia
#ifndef __INFINIOP_ROPE_V2_CUDA_H__
#define __INFINIOP_ROPE_V2_CUDA_H__
#include "../rope_v2.h"
DESCRIPTOR(nvidia)
#endif // __INFINIOP_ROPE_V2_CUDA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/rope_v2.h"
#ifdef ENABLE_CPU_API
#include "cpu/rope_v2_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/rope_v2_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_v2_ascend.h"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/rope_v2_bang.h"
#endif
#ifdef ENABLE_METAX_API
#include "metax/rope_v2_metax.h"
#endif
// Public C API: create a RoPE-v2 descriptor for whichever backend the
// handle's device selects. Each CREATE case returns directly; unsupported
// devices fall through to the trailing error return.
__C infiniStatus_t infiniopCreateRoPEv2Descriptor(
infiniopHandle_t handle,
infiniopRoPEv2Descriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::rope_v2::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::rope_v2::NAMESPACE::Descriptor **>(desc_ptr), \
y, \
x, \
pos_ids, \
sin_table, \
cos_table)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// Iluvatar reuses the nvidia implementation.
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_MTHREADS_GPU
// NOTE(review): `t` below is not declared anywhere in this function, so
// this branch cannot compile when ENABLE_MTHREADS_GPU is defined. It
// looks like a leftover from the pre-merge in-place RoPE API — verify
// the intended musa arguments (likely the y/x descriptors).
case DevMthreadsGpu: {
return musaCreateRoPEDescriptor((MusaHandle_t)handle,
(RoPEMusaDescriptor_t *)desc_ptr, t,
pos_ids, sin_table, cos_table);
}
#endif
}
#undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Public C API: report the workspace byte count the descriptor's backend
// requires (currently 0 for all in-tree backends).
__C infiniStatus_t infiniopGetRoPEv2WorkspaceSize(infiniopRoPEv2Descriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::rope_v2::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaGetRoPEWorkspaceSize((RoPEMusaDescriptor_t)desc, size);
}
#endif
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Public C API: run RoPE-v2 via the backend selected at descriptor-creation
// time. Each CALCULATE case returns the backend's status directly.
__C infiniStatus_t infiniopRoPEv2(
infiniopRoPEv2Descriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::rope_v2::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, x, pos_ids, sin_table, cos_table, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MTHREADS_GPU
// NOTE(review): `t` below is not declared anywhere in this function, so
// this branch cannot compile when ENABLE_MTHREADS_GPU is defined — it
// looks like a leftover of the old in-place signature; the y/x pointers
// are presumably what should be passed. Verify against the musa API.
case DevMthreadsGpu: {
return musaRoPE((RoPEMusaDescriptor_t)desc, workspace, workspace_size,
t, pos_ids, sin_table, cos_table, stream);
}
#endif
}
#undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Public C API: destroy a RoPE-v2 descriptor, deleting through the concrete
// backend type so the right destructor (and Opaque cleanup) runs.
__C infiniStatus_t
infiniopDestroyRoPEv2Descriptor(infiniopRoPEv2Descriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::rope_v2::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_ASCEND_API
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaDestroyRoPEDescriptor((RoPEMusaDescriptor_t)desc);
}
#endif
}
#undef DELETE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#ifndef __ROPE_V2_H__
#define __ROPE_V2_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
// Generates the per-backend RoPE-v2 Descriptor class inside
// op::rope_v2::<NAMESPACE>. Each backend's .cc/.cu file defines the
// destructor, create() and calculate() declared here; Opaque holds any
// backend-private state. (Comments cannot appear inside the line-continued
// macro body below, hence this header block.)
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::rope_v2::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
RoPEv2Info _info; \
size_t _workspace_size; \
\
Descriptor( \
RoPEv2Info info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t pos_desc, \
infiniopTensorDescriptor_t sin_desc, \
infiniopTensorDescriptor_t cos_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *y, \
const void *x, \
const void *pos_ids, \
const void *sin_table, \
const void *cos_table, \
void *stream) const; \
}; \
}
// Validated, plain-data description of one RoPE-v2 problem:
//   y/x: [seqlen, nhead, dhead] with dhead == 2 * table_dim,
//   pos_ids: [seqlen] (any integer dtype),
//   sin/cos tables: [table_len, table_dim], fully contiguous.
// Only the seqlen/nhead strides of y and x are kept; the last dimension is
// required to be contiguous.
class RoPEv2Info {
private:
    // Only constructible via createRoPEv2Info so instances are validated.
    RoPEv2Info() = default;

public:
    infiniDtype_t data_type, pos_type;
    size_t seqlen, nhead, dhead, table_len, table_dim;
    ptrdiff_t
        y_stride_seqlen,
        y_stride_nhead,
        x_stride_seqlen,
        x_stride_nhead;

    // Validates the five tensor descriptors and extracts the fields above.
    // Returns a failed Result (never throws) on null descriptors, dtype
    // mismatches, shape mismatches, or non-contiguous layouts.
    static utils::Result<RoPEv2Info> createRoPEv2Info(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc,
        infiniopTensorDescriptor_t pos_desc,
        infiniopTensorDescriptor_t sin_desc,
        infiniopTensorDescriptor_t cos_desc) {
        // Fix: x_desc was missing from the null check even though it is
        // dereferenced immediately below.
        CHECK_OR_RETURN(
            y_desc != nullptr && x_desc != nullptr && pos_desc != nullptr && sin_desc != nullptr && cos_desc != nullptr,
            INFINI_STATUS_NULL_POINTER);
        const infiniDtype_t data_type = y_desc->dtype();
        const infiniDtype_t pos_type = pos_desc->dtype();
        CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(),
                        INFINI_STATUS_BAD_TENSOR_DTYPE);
        CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
        CHECK_DTYPE_ANY_INT(pos_type);
        CHECK_OR_RETURN(y_desc->ndim() == 3
                            && x_desc->ndim() == 3
                            && pos_desc->ndim() == 1
                            && sin_desc->ndim() == 2
                            && cos_desc->ndim() == 2,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        const auto seqlen = y_desc->dim(0),
                   nhead = y_desc->dim(1),
                   dhead = y_desc->dim(2),
                   table_len = sin_desc->dim(0),
                   table_dim = sin_desc->dim(1);
        CHECK_OR_RETURN(seqlen == x_desc->dim(0)
                            && seqlen == pos_desc->dim(0)
                            && nhead == x_desc->dim(1) && dhead == x_desc->dim(2)
                            && table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        // Each table entry rotates one pair of head elements.
        CHECK_OR_RETURN(dhead == table_dim * 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
        // Last dimension of x and y must be contiguous
        CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        // sin table and cos table must be totally contiguous
        CHECK_OR_RETURN(sin_desc->isContiguous() && cos_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES);
        return utils::Result<RoPEv2Info>(RoPEv2Info{
            data_type,
            pos_type,
            seqlen,
            nhead,
            dhead,
            table_len,
            table_dim,
            y_desc->stride(0),
            y_desc->stride(1),
            x_desc->stride(0),
            x_desc->stride(1),
        });
    }
};
...@@ -2,27 +2,48 @@ from ast import List ...@@ -2,27 +2,48 @@ from ast import List
import numpy as np import numpy as np
import gguf import gguf
from typing import List from typing import List
from enum import Enum
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides
class Algorithm(Enum):
GPT_J = 0
GPT_NEOX = 1
def rotary_embedding(t, sin, cos):
dh = t.shape[2]
assert dh % 2 == 0, "Embedding dimension must be even."
t_even = t[..., 0::2] # [seq_len, n_head, dh // 2] def rotary_embedding(t, sin, cos, algo):
t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2] def _rope(sin, cos, t1, t2):
cos = np.expand_dims(cos, axis=1) # [seq_len, 1, dh // 2]
sin = np.expand_dims(sin, axis=1) # [seq_len, 1, dh // 2]
cos = np.expand_dims(cos, axis=1) # [seq_len, 1, dh // 2] t_out_1 = t1 * cos - t2 * sin
sin = np.expand_dims(sin, axis=1) # [seq_len, 1, dh // 2] t_out_2 = t1 * sin + t2 * cos
t_out_even = t_even * cos - t_odd * sin return t_out_1, t_out_2
t_out_odd = t_even * sin + t_odd * cos
dh = t.shape[-1]
assert dh % 2 == 0, "Embedding dimension must be even."
t_out = np.empty_like(t) t_out = np.empty_like(t)
t_out[..., 0::2] = t_out_even
t_out[..., 1::2] = t_out_odd if algo == Algorithm.GPT_J.value:
t_even = t[..., 0::2] # [seq_len, n_head, dh // 2]
t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2]
t_out_even, t_out_odd = _rope(sin, cos, t_even, t_odd)
t_out[..., 0::2] = t_out_even
t_out[..., 1::2] = t_out_odd
else:
half_dim = dh // 2
t_first = t[..., :half_dim]
t_second = t[..., half_dim:]
t_out_first, t_out_second = _rope(sin, cos, t_first, t_second)
t_out[..., :half_dim] = t_out_first
t_out[..., half_dim:] = t_out_second
return t_out return t_out
...@@ -52,6 +73,7 @@ class RoPETestCase(InfiniopTestCase): ...@@ -52,6 +73,7 @@ class RoPETestCase(InfiniopTestCase):
pos_ids: np.ndarray, pos_ids: np.ndarray,
sin_table: np.ndarray, sin_table: np.ndarray,
cos_table: np.ndarray, cos_table: np.ndarray,
algo: int,
): ):
super().__init__("rope") super().__init__("rope")
self.y = y self.y = y
...@@ -63,10 +85,12 @@ class RoPETestCase(InfiniopTestCase): ...@@ -63,10 +85,12 @@ class RoPETestCase(InfiniopTestCase):
self.pos_ids = pos_ids self.pos_ids = pos_ids
self.sin_table = sin_table self.sin_table = sin_table
self.cos_table = cos_table self.cos_table = cos_table
self.algo = algo
def write_test(self, test_writer: "InfiniopTestWriter"): def write_test(self, test_writer: "InfiniopTestWriter"):
super().write_test(test_writer) super().write_test(test_writer)
test_writer.add_int32(test_writer.gguf_key("algo"), self.algo)
test_writer.add_tensor( test_writer.add_tensor(
test_writer.gguf_key("y"), self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype) test_writer.gguf_key("y"), self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype)
) )
...@@ -97,6 +121,7 @@ class RoPETestCase(InfiniopTestCase): ...@@ -97,6 +121,7 @@ class RoPETestCase(InfiniopTestCase):
self.x.astype(np.float64), self.x.astype(np.float64),
self.sin_table.astype(np.float64), self.sin_table.astype(np.float64),
self.cos_table.astype(np.float64), self.cos_table.astype(np.float64),
self.algo,
) )
test_writer.add_tensor( test_writer.add_tensor(
test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64 test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64
...@@ -121,27 +146,35 @@ if __name__ == "__main__": ...@@ -121,27 +146,35 @@ if __name__ == "__main__":
((3, 32, 128), (8000, 200, 1), (7000, 128, 1)), ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
] ]
_ALGO = [
Algorithm.GPT_J,
Algorithm.GPT_NEOX,
]
_TENSOR_DTYPES_ = [np.float16, np.float32] _TENSOR_DTYPES_ = [np.float16, np.float32]
test_writer = InfiniopTestWriter("rope.gguf") test_writer = InfiniopTestWriter("rope.gguf")
test_cases = [] test_cases = []
for dtype in _TENSOR_DTYPES_: for algo in _ALGO:
for shape, stride_x, stride_y in _TEST_CASES_: for dtype in _TENSOR_DTYPES_:
x = np.random.rand(*shape).astype(dtype) for shape, stride_x, stride_y in _TEST_CASES_:
y = np.empty(tuple(0 for _ in shape), dtype=dtype) x = np.random.rand(*shape).astype(dtype)
pos_ids = np.arange(0, x.shape[0], dtype=np.int32) y = np.empty(tuple(0 for _ in shape), dtype=dtype)
sin_table, cos_table = sin_cos_table(pos_ids, x.shape[2], theta=1e5, dtype=dtype) pos_ids = np.arange(0, x.shape[0], dtype=np.int32)
test_case = RoPETestCase( sin_table, cos_table = sin_cos_table(pos_ids, x.shape[2], theta=1e5, dtype=dtype)
y=y, test_case = RoPETestCase(
x=x, y=y,
shape_y=shape, x=x,
shape_x=shape, shape_y=shape,
stride_y=stride_y, shape_x=shape,
stride_x=stride_x, stride_y=stride_y,
pos_ids=pos_ids, stride_x=stride_x,
sin_table=sin_table, pos_ids=pos_ids,
cos_table=cos_table, sin_table=sin_table,
) cos_table=cos_table,
test_cases.append(test_case) algo=algo.value,
)
test_cases.append(test_case)
test_writer.add_tests(test_cases) test_writer.add_tests(test_cases)
test_writer.save() test_writer.save()
...@@ -361,6 +361,8 @@ def rope_(lib): ...@@ -361,6 +361,8 @@ def rope_(lib):
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_int32,
] ]
lib.infiniopGetRoPEWorkspaceSize.restype = c_int32 lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
...@@ -379,6 +381,7 @@ def rope_(lib): ...@@ -379,6 +381,7 @@ def rope_(lib):
c_void_p, c_void_p,
c_void_p, c_void_p,
c_void_p, c_void_p,
c_void_p,
] ]
lib.infiniopDestroyRoPEDescriptor.restype = c_int32 lib.infiniopDestroyRoPEDescriptor.restype = c_int32
...@@ -387,42 +390,6 @@ def rope_(lib): ...@@ -387,42 +390,6 @@ def rope_(lib):
] ]
@OpRegister.operator
def rope_v2_(lib):
lib.infiniopCreateRoPEv2Descriptor.restype = c_int32
lib.infiniopCreateRoPEv2Descriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetRoPEv2WorkspaceSize.restype = c_int32
lib.infiniopGetRoPEv2WorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopRoPEv2.restype = c_int32
lib.infiniopRoPEv2.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyRoPEv2Descriptor.restype = c_int32
lib.infiniopDestroyRoPEv2Descriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator @OpRegister.operator
def sub_(lib): def sub_(lib):
lib.infiniopCreateSubDescriptor.restype = c_int32 lib.infiniopCreateSubDescriptor.restype = c_int32
......
...@@ -51,15 +51,27 @@ class Inplace(Enum): ...@@ -51,15 +51,27 @@ class Inplace(Enum):
INPLACE_X = auto() INPLACE_X = auto()
class Algorithm(Enum):
GPT_J = 0
GPT_NEOX = 1
_INPLACE = [ _INPLACE = [
Inplace.OUT_OF_PLACE, Inplace.OUT_OF_PLACE,
Inplace.INPLACE_X, Inplace.INPLACE_X,
] ]
_ALGO = [
Algorithm.GPT_J,
Algorithm.GPT_NEOX,
]
_TEST_CASES = [ _TEST_CASES = [
test_case + (inplace_item,) test_case + (inplace_item, algo_item)
for test_case in _TEST_CASES_ for test_case in _TEST_CASES_
for inplace_item in _INPLACE for inplace_item in _INPLACE
for algo_item in _ALGO
] ]
DEBUG = False DEBUG = False
...@@ -68,27 +80,45 @@ NUM_PRERUN = 10 ...@@ -68,27 +80,45 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
def rotary_embedding(ans, t, sin, cos, device): def rotary_embedding(ans, t, sin, cos, device, algo):
dh = t.shape[2] def _torch_rope(sin, cos, t1, t2):
cos = cos.unsqueeze(1) # [seq_len, 1, dh // 2]
sin = sin.unsqueeze(1) # [seq_len, 1, dh // 2]
if device == InfiniDeviceEnum.CPU:
(t1, t2, cos, sin) = (
t1.float(),
t2.float(),
cos.float(),
sin.float(),
)
t_out_1 = t1 * cos - t2 * sin
t_out_2 = t1 * sin + t2 * cos
return t_out_1, t_out_2
dh = t.shape[-1]
dt = t.dtype dt = t.dtype
assert dh % 2 == 0, "Embedding dimension must be even." assert dh % 2 == 0, "Embedding dimension must be even."
t_even = t[..., 0::2] # [seq_len, n_head, dh // 2]
t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2]
cos = cos.unsqueeze(1) # [seq_len, 1, dh // 2]
sin = sin.unsqueeze(1) # [seq_len, 1, dh // 2]
if device == InfiniDeviceEnum.CPU:
(t_even, t_odd, cos, sin) = (
t_even.float(),
t_odd.float(),
cos.float(),
sin.float(),
)
t_out_even = t_even * cos - t_odd * sin if algo == Algorithm.GPT_J:
t_out_odd = t_even * sin + t_odd * cos t_even = t[..., 0::2] # [seq_len, n_head, dh // 2]
t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2]
t_out_even, t_out_odd = _torch_rope(sin, cos, t_even, t_odd)
ans[..., 0::2] = t_out_even.to(dt)
ans[..., 1::2] = t_out_odd.to(dt)
else:
half_dim = dh // 2
t_first = t[..., :half_dim]
t_second = t[..., half_dim:]
t_out_first, t_out_second = _torch_rope(sin, cos, t_first, t_second)
ans[..., 0::2] = t_out_even.to(dt) ans[..., :half_dim] = t_out_first.to(dt)
ans[..., 1::2] = t_out_odd.to(dt) ans[..., half_dim:] = t_out_second.to(dt)
def sin_cos_table(pos, dim, device, theta, dtype): def sin_cos_table(pos, dim, device, theta, dtype):
...@@ -108,6 +138,7 @@ def test( ...@@ -108,6 +138,7 @@ def test(
x_strides=None, x_strides=None,
y_strides=None, y_strides=None,
inplace=Inplace.OUT_OF_PLACE, inplace=Inplace.OUT_OF_PLACE,
algo=Algorithm.GPT_J,
dtype=torch.float32, dtype=torch.float32,
sync=None, sync=None,
): ):
...@@ -120,7 +151,7 @@ def test( ...@@ -120,7 +151,7 @@ def test(
y = TestTensor(shape, y_strides, dtype, device) y = TestTensor(shape, y_strides, dtype, device)
print( print(
f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}" f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace} algo:{algo}"
) )
theta = 1e5 theta = 1e5
pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device) pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
...@@ -134,6 +165,7 @@ def test( ...@@ -134,6 +165,7 @@ def test(
sin_table.torch_tensor(), sin_table.torch_tensor(),
cos_table.torch_tensor(), cos_table.torch_tensor(),
device, device,
algo,
) )
descriptor = infiniopOperatorDescriptor_t() descriptor = infiniopOperatorDescriptor_t()
...@@ -150,6 +182,7 @@ def test( ...@@ -150,6 +182,7 @@ def test(
pos.descriptor, pos.descriptor,
sin_table.descriptor, sin_table.descriptor,
cos_table.descriptor, cos_table.descriptor,
algo.value,
) )
) )
......
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceEnum,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # (shape, x_strides, y_strides)
    ((1, 32, 128), None, None),
    ((10, 32, 64), None, None),
    # Ascend (昇腾) cannot yet handle this case: a last dimension <= 32 is
    # problematic, possibly related to the internal implementation of its core
    # GatherMask interface; last dimensions of 48, 64 and 128 are all supported.
    ((4, 1, 32), (64, 64, 1), None),
    ((11, 33, 128), None, (8000, 200, 1)),
    ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]

# Tolerance map for different data types
# (looser bounds for the reduced-precision float formats)
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
    InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-3},
}
class Inplace(Enum):
    # Execution modes for each test case.
    OUT_OF_PLACE = auto()  # y is a separate output tensor
    INPLACE_X = auto()  # y aliases x, i.e. the op writes back into its input
# In-place modes each base case is expanded with.
_INPLACE = [
    Inplace.OUT_OF_PLACE,
    Inplace.INPLACE_X,
]

# Cartesian product: every base (shape, strides) case paired with every
# in-place mode, appended as the final tuple element.
_TEST_CASES = [
    base_case + (placement,)
    for base_case in _TEST_CASES_
    for placement in _INPLACE
]
# Runtime options; overwritten from command-line arguments in __main__.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def rotary_embedding(ans, t, sin, cos, device):
    """Reference rotary positional embedding, half-split variant.

    The last dimension of `t` is split into a first and a second half;
    each (first, second) pair is rotated by the angle encoded in the
    sin/cos tables, and the result is written into `ans` in place.

    Args:
        ans: output tensor, same shape as `t`; written in place.
        t: input tensor whose last dimension is the embedding dimension.
        sin, cos: angle tables, broadcast over the head dimension.
        device: InfiniDeviceEnum value; on CPU the math is done in float32
            for accuracy with fp16/bf16 inputs.
    """
    emb_dim = t.shape[-1]
    out_dtype = t.dtype
    assert emb_dim % 2 == 0, "Embedding dimension must be even."

    half = emb_dim // 2
    lo = t[..., :half]
    hi = t[..., half:]

    # [seq_len, 1, half_dim] so the tables broadcast across heads
    cos_b = cos.unsqueeze(1)
    sin_b = sin.unsqueeze(1)

    if device == InfiniDeviceEnum.CPU:
        lo, hi = lo.float(), hi.float()
        cos_b, sin_b = cos_b.float(), sin_b.float()

    # Standard 2-D rotation applied pairwise across the two halves.
    ans[..., :half] = (lo * cos_b - hi * sin_b).to(out_dtype)
    ans[..., half:] = (lo * sin_b + hi * cos_b).to(out_dtype)
def sin_cos_table(pos, dim, device, theta, dtype):
    """Build the sin/cos rotation tables for the given positions.

    Args:
        pos: 1-D tensor of position indices.
        dim: embedding dimension (must be even); tables have dim // 2 columns.
        device: device the resulting TestTensors are created on.
        theta: frequency base (e.g. 1e5).
        dtype: InfiniDtype of the produced tables.

    Returns:
        (sin_table, cos_table) as a pair of TestTensor objects of shape
        [len(pos), dim // 2].
    """
    assert dim % 2 == 0, "Embedding dimension must be even."
    half = dim // 2
    # Inverse frequencies: theta^(-2i/dim) for i in [0, half)
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    freqs = 1.0 / (theta ** exponents)
    # Outer product positions x frequencies -> per-(position, pair) angles
    angles = torch.outer(pos.cpu(), freqs)
    sin_table = TestTensor.from_torch(torch.sin(angles), dtype, device)
    cos_table = TestTensor.from_torch(torch.cos(angles), dtype, device)
    return sin_table, cos_table
def test(
    handle,
    device,
    shape,
    x_strides=None,
    y_strides=None,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=torch.float32,
    sync=None,
):
    """Run one RoPEv2 test case: compare the library kernel against the
    PyTorch reference `rotary_embedding` within dtype-specific tolerances.

    Args:
        handle: infiniop library handle for the target device.
        device: InfiniDeviceEnum value being tested.
        shape: input/output tensor shape (seq_len, n_head, head_dim).
        x_strides / y_strides: optional explicit strides for x and y.
        inplace: whether y aliases x or is a separate tensor.
        dtype: InfiniDtype of the test tensors.
        sync: optional device synchronization callback (None on CPU).
    """
    x = TestTensor(shape, x_strides, dtype, device)
    if inplace == Inplace.INPLACE_X:
        # In-place only makes sense when input and output layouts agree;
        # skip the case otherwise.
        if x_strides != y_strides:
            return
        y = x
    else:
        y = TestTensor(shape, y_strides, dtype, device)

    print(
        f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
    )

    theta = 1e5
    # One position id per sequence index; int32 as required by the op.
    pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
    sin_table, cos_table = sin_cos_table(
        pos.torch_tensor(), x.shape[2], x.device, theta, dtype
    )

    # Compute the reference answer into y's torch-side buffer before the
    # library kernel runs (matters for the in-place case).
    rotary_embedding(
        y.torch_tensor(),
        x.torch_tensor(),
        sin_table.torch_tensor(),
        cos_table.torch_tensor(),
        device,
    )

    descriptor = infiniopOperatorDescriptor_t()

    if sync is not None:
        sync()

    check_error(
        LIBINFINIOP.infiniopCreateRoPEv2Descriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            pos.descriptor,
            sin_table.descriptor,
            cos_table.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [y, x, pos, sin_table, cos_table]:
        tensor.destroy_desc()

    # Query and allocate the kernel's scratch workspace.
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRoPEv2WorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, x.device)

    def lib_rope_v2():
        # Invoke the library kernel; final None is the stream argument.
        check_error(
            LIBINFINIOP.infiniopRoPEv2(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                x.data(),
                pos.data(),
                sin_table.data(),
                cos_table.data(),
                None,
            )
        )

    lib_rope_v2()

    if sync is not None:
        sync()

    # Compare kernel output against the reference within dtype tolerances.
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(
                y.torch_tensor(),
                x.torch_tensor(),
                sin_table.torch_tensor(),
                cos_table.torch_tensor(),
                device,
            ),
            device,
            NUM_PRERUN,
            NUM_ITERATIONS,
        )
        profile_operation(
            "    lib", lambda: lib_rope_v2(), device, NUM_PRERUN, NUM_ITERATIONS
        )
    check_error(LIBINFINIOP.infiniopDestroyRoPEv2Descriptor(descriptor))
if __name__ == "__main__":
    args = get_args()

    # Apply command-line configuration to the module-level switches.
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations

    # Run every test case on every requested device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment