Unverified commit 9ad23fad authored by blkmjsian, committed by GitHub

[T2-2-3] blkmjsian

- dequantize awq
- rope v2
parent b3170335
@@ -7,6 +7,7 @@
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/random_sample.h"
@@ -14,8 +15,10 @@
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rope.h"
#include "infiniop/ops/rope_v2.h"
#include "infiniop/ops/sub.h"
#include "infiniop/ops/swiglu.h"
#include "infiniop/ops/topkrouter.h"
#include "infiniop/tensor_descriptor.h"
#endif // __INFINIOP_API_H__
#ifndef __INFINIOP_DEQUANTIZE_API_H__
#define __INFINIOP_DEQUANTIZE_API_H__
#include "../operator_descriptor.h"
typedef struct InfiniopDescriptor *infiniopDequantizeDescriptor_t;
__C __export infiniStatus_t infiniopCreateDequantizeDescriptor(infiniopHandle_t handle,
infiniopDequantizeDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t qweight_desc,
infiniopTensorDescriptor_t scales_desc,
infiniopTensorDescriptor_t zeros_desc);
__C __export infiniStatus_t infiniopGetDequantizeWorkspaceSize(infiniopDequantizeDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopDequantize(infiniopDequantizeDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *qweight,
const void *scales,
const void *zeros,
size_t split_k_iters,
size_t thx,
size_t thy,
void *stream);
__C __export infiniStatus_t infiniopDestroyDequantizeDescriptor(infiniopDequantizeDescriptor_t desc);
#endif
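A minimal call-site sketch of the API above (hypothetical: the handle, tensor descriptors, and device buffers are assumed to be created elsewhere; error handling elided). Passing 0 for split_k_iters, thx, and thy lets the NVIDIA backend later in this commit fall back to its default 8x8 launch geometry.

// Hedged usage sketch; every variable shown is an assumption, not part of this commit.
infiniopDequantizeDescriptor_t desc;
infiniopCreateDequantizeDescriptor(handle, &desc, out_desc, qweight_desc, scales_desc, zeros_desc);
size_t workspace_size = 0;
infiniopGetDequantizeWorkspaceSize(desc, &workspace_size);
void *workspace = nullptr; // allocate workspace_size bytes if nonzero
infiniopDequantize(desc, workspace, workspace_size, d_out, d_qweight, d_scales, d_zeros,
                   /*split_k_iters=*/0, /*thx=*/0, /*thy=*/0, stream);
infiniopDestroyDequantizeDescriptor(desc);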
#ifndef __INFINIOP_ROPE_V2_API_H__
#define __INFINIOP_ROPE_V2_API_H__
#include "../operator_descriptor.h"
typedef struct InfiniopDescriptor *infiniopRoPEv2Descriptor_t;
__C __export infiniStatus_t infiniopCreateRoPEv2Descriptor(
infiniopHandle_t handle,
infiniopRoPEv2Descriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table);
__C __export infiniStatus_t infiniopGetRoPEv2WorkspaceSize(infiniopRoPEv2Descriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopRoPEv2(
infiniopRoPEv2Descriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream);
__C __export infiniStatus_t infiniopDestroyRoPEv2Descriptor(infiniopRoPEv2Descriptor_t desc);
#endif
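A hedged sketch of the expected shapes, inferred from the CPU implementation later in this commit rather than from a documented contract: x and y are [seqlen, nhead, head_dim], pos_ids is [seqlen], and sin_table/cos_table are [max_pos, head_dim / 2] tables indexed by position id.

// Hypothetical call sequence; descriptors and buffers assumed valid.
infiniopRoPEv2Descriptor_t rope_desc;
infiniopCreateRoPEv2Descriptor(handle, &rope_desc, y_desc, x_desc, pos_desc, sin_desc, cos_desc);
size_t rope_ws = 0;
infiniopGetRoPEv2WorkspaceSize(rope_desc, &rope_ws);
infiniopRoPEv2(rope_desc, nullptr, rope_ws, d_y, d_x, d_pos, d_sin, d_cos, stream);
infiniopDestroyRoPEv2Descriptor(rope_desc);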
#ifndef __INFINIOP_TOPKROUTER_API_H__
#define __INFINIOP_TOPKROUTER_API_H__
#include "../operator_descriptor.h"
typedef struct InfiniopDescriptor *infiniopTopkrouterDescriptor_t;
__C __export infiniStatus_t infiniopCreateTopkrouterDescriptor(
infiniopHandle_t handle,
infiniopTopkrouterDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t correction_bias_desc);
__C __export infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void *workspace, size_t workspace_size,
void *values, void *indices, void *x, void *correction_bias, float routed_scaling_factor, size_t topk, void *stream);
__C __export infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc);
#endif
#ifndef __DEQUANTIZE_H__
#define __DEQUANTIZE_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::dequantize::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
DequantizeInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
size_t workspace_size_, \
Opaque *opaque, \
DequantizeInfo info, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t qweight_desc, \
infiniopTensorDescriptor_t scales_desc, \
infiniopTensorDescriptor_t zeros_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *out, \
const void *qweight, \
const void *scales, \
const void *zeros, \
int split_k_iters, \
int thx, \
int thy, \
void *stream) const; \
}; \
}
#endif
#ifndef __DEQUANTIZE_INFO_H__
#define __DEQUANTIZE_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::dequantize {
class DequantizeInfo {
public:
    // Aggregate (no user-declared constructors), so the brace-init in
    // create() below is well-formed.
    int _in_c, _qout_c, _G;
    int in_c() const { return _in_c; }
    int qout_c() const { return _qout_c; }
    int G() const { return _G; }
    static utils::Result<DequantizeInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t qweight_desc,
        infiniopTensorDescriptor_t scales_desc,
        infiniopTensorDescriptor_t zeros_desc) {
        // qweight is [in_c, qout_c] packed int32; scales carries one row per
        // quantization group. out_desc and zeros_desc are not yet validated.
        int in_c = int(qweight_desc->dim(0));
        int qout_c = int(qweight_desc->dim(1));
        int num_groups = int(scales_desc->dim(0));
        return utils::Result<DequantizeInfo>(DequantizeInfo{in_c, qout_c, num_groups});
    }
};
} // namespace op::dequantize
#endif // __DEQUANTIZE_INFO_H__
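A worked shape example for the AWQ layout implied by this info class and the NVIDIA kernel below (all sizes illustrative, not mandated by the code):

// Hypothetical: in_c = 4096, group_size = 128.
// qweight: [in_c, qout_c] = [4096, 512]  packed int32, 8 x 4-bit weights per word
// out:     [in_c, out_c]  = [4096, 4096] fp16, out_c = qout_c * 8
// scales:  [G, out_c]     = [32, 4096]   fp16, G = in_c / group_size = 32 groups
// zeros:   [G, qout_c]    = [32, 512]    packed int32, same nibble packing as qweight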
#pragma once
#include <cassert>
#include <cstdint>
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false);
#else
uint4 result;
uint32_t *h = reinterpret_cast<uint32_t *>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; // lop3 LUT computing (a & b) | c
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[0])
: "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[1])
: "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[2])
: "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[3])
: "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
\ No newline at end of file
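To make the magic constants above less opaque: fp16 bit patterns 0x6400 through 0x67FF all have exponent 2^10, so their numeric value is exactly 1024 + mantissa. OR-ing a low nibble n into 0x6400 therefore yields 1024 + n, which the sub.f16x2 by FP16_TOP_MAGIC_NUM (1024.0) turns back into n; a high nibble lands as 1024 + 16n, which fma.rn.f16x2 with ONE_SIXTEENTH and NEG_64 maps to (1024 + 16n) / 16 - 64 = n. A standalone host-side check (a sketch; no CUDA required):

#include <cassert>
#include <cstdint>

// Value of an fp16 bit pattern in [0x6400, 0x6800): exactly 1024 + mantissa.
static int fp16_magic_value(uint16_t bits) { return 1024 + (bits & 0x03FF); }

int main() {
    for (int n = 0; n < 16; ++n) {
        uint16_t lo = uint16_t(0x6400 | n);          // BOTTOM_MASK path
        uint16_t hi = uint16_t(0x6400 | (n << 4));   // TOP_MASK path
        assert(fp16_magic_value(lo) - 1024 == n);    // sub.f16x2 recovers n
        assert(fp16_magic_value(hi) / 16 - 64 == n); // fma.rn.f16x2 recovers n
    }
    return 0;
}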
#include "../../../devices/nvidia/nvidia_handle.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "dequantize_w42f16_kernel.cuh"
#include "dequantize_w42f16_nvidia.cuh"
#include "../dequantize.h"
#include <cuda_fp16.h>
__global__ void __launch_bounds__(64)
dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors,
int *__restrict__ zeros, half *__restrict__ C, int G) {
static constexpr uint32_t ZERO = 0x0;
// Per-thread staging buffer; only the first 8 halves are used below.
half B_shared[32 * (128 + 8)];
half *B_shared_ptr2 = B_shared;
int N = blockDim.x * gridDim.x; // total packed columns (qout_c)
int col = blockIdx.x * blockDim.x + threadIdx.x; // packed-column index
int row = blockIdx.y * blockDim.y + threadIdx.y; // input-channel index
int index1 = 8 * col + 8 * row * N; // fp16 output offset
half *C_ptr2 = C + index1;
int index2 = col + row * N; // packed-weight offset
int *B_ptr2 = B + index2;
int index3 = col + (row / G) * N; // packed zero-point offset (one per group)
int *zeros_ptr2 = zeros + index3;
int index4 = 8 * col + (row / G) * N * 8; // scale offset (one row per group)
half *scaling_factors_ptr2 = scaling_factors + index4;
uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2);
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4 *)(scaling_factors_ptr2);
uint32_t B_loaded = *(uint32_t *)B_ptr2;
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.x)
: "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.x)
: "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.y)
: "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.z)
: "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(B_loaded_fp16.w)
: "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
*(uint4 *)B_shared_ptr2 = B_loaded_fp16;
// Write the eight dequantized fp16 values out to global memory.
for (int i = 0; i < 8; ++i) {
*(C_ptr2 + i) = B_shared[i];
}
}
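Per element, the kernel above computes the standard AWQ affine dequantization w_fp16 = (w_q - z_q) * scale, eight lanes at a time. A scalar host-side reference of one lane (a sketch that deliberately ignores the in-register nibble permutation that dequantize_s4_to_fp16x2 handles):

#include <cstdint>

// One 4-bit lane of one packed word; the `lane` ordering is illustrative only.
float dequant_lane(uint32_t qw, uint32_t qz, float scale, int lane) {
    int w = int((qw >> (4 * lane)) & 0xF); // 4-bit quantized weight
    int z = int((qz >> (4 * lane)) & 0xF); // 4-bit zero point for the group
    return float(w - z) * scale;           // sub.f16x2, then fma with the scale
}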
namespace op::dequantize::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t qweight_desc,
infiniopTensorDescriptor_t scales_desc,
infiniopTensorDescriptor_t zeros_desc) {
auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
auto result = DequantizeInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
0,
new Opaque{handle->internal()},
result.take(),
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t
Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *qweight,
const void *scales,
const void *zeros,
int split_k_iters,
int thx,
int thy,
void *stream) const {
int in_c = _info.in_c();
int qout_c = _info.qout_c();
int out_c = qout_c * 8; // each packed int32 column expands to 8 fp16 columns
(void)out_c;            // shape bookkeeping only; split_k_iters is likewise unused here
int G = in_c / _info.G(); // quantization group size (in_c / num_groups)
int x_thread = thx;
int y_thread = thy;
int x_blocks = 1;
int y_blocks = 1;
if (thx == 0) {
x_thread = qout_c;
}
if (thy == 0) {
y_thread = in_c;
}
if (thx == 0 && thy == 0) {
x_thread = 8;
y_thread = 8;
x_blocks = (int)(qout_c / 8);
y_blocks = (int)(in_c / 8);
}
half *out_ = reinterpret_cast<half *>(out);
int *qweight_ = const_cast<int *>(reinterpret_cast<const int *>(qweight));
half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales));
int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros));
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_thread, y_thread);
dequantize_weights<<<num_blocks, threads_per_block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
qweight_, scales_, zeros_, out_, G);
return INFINI_STATUS_SUCCESS;
}
} // namespace op::dequantize::nvidia
\ No newline at end of file
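For the default thx == thy == 0 path in calculate() above, the launch geometry works out as follows (hypothetical shape; one thread per packed int32, eight fp16 outputs each):

// Illustrative sizes: in_c = 4096, qout_c = 512 (so out_c = 4096).
dim3 threads_per_block(8, 8);       // x_thread = y_thread = 8
dim3 num_blocks(512 / 8, 4096 / 8); // 64 x 512 blocks
// Total threads = qout_c * in_c: one packed word per thread.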
#ifndef __DEQUANTIZE_CUDA_CUH__
#define __DEQUANTIZE_CUDA_CUH__
#include "../dequantize.h"
DESCRIPTOR(nvidia)
#endif // __DEQUANTIZE_CUDA_CUH__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/dequantize.h"
#ifdef ENABLE_NVIDIA_API
#include "nvidia/dequantize_w42f16_nvidia.cuh"
#endif
__C infiniStatus_t infiniopCreateDequantizeDescriptor(
infiniopHandle_t handle,
infiniopDequantizeDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t qweight_desc,
infiniopTensorDescriptor_t scales_desc,
infiniopTensorDescriptor_t zeros_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::dequantize::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::dequantize::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, \
qweight_desc, \
scales_desc, \
zeros_desc)
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetDequantizeWorkspaceSize(infiniopDequantizeDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
}
__C infiniStatus_t infiniopDequantize(
infiniopDequantizeDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *qweight,
const void *scales,
const void *zeros,
size_t split_k_iters,
size_t thx,
size_t thy,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, out, qweight, scales, zeros, split_k_iters, thx, thy, stream)
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t
infiniopDestroyDequantizeDescriptor(infiniopDequantizeDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
\ No newline at end of file
@@ -76,7 +76,7 @@ infiniStatus_t rmsnormHalfPrecision(const RMSNormInfo *info, T *y, const T *x, c
if constexpr (std::is_same<Tw, float>::value) {
float val = utils::cast<float>(x_ptr[k]) * w[k] * rms;
y_ptr[k] = utils::cast<T>(val);
- } else if constexpr (std::is_same<Tw, T>::value) {
+ } else if constexpr (std::is_same_v<Tw, T> || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
float val = utils::cast<float>(x_ptr[k]) * utils::cast<float>(w[k]) * rms;
y_ptr[k] = utils::cast<T>(val);
} else {
@@ -97,6 +97,8 @@ infiniStatus_t Descriptor::calculate(
CHECK_STATUS(rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)x, (const float *)w));
} else if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)x, (const bf16_t *)w));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
@@ -105,6 +107,8 @@ infiniStatus_t Descriptor::calculate(
CHECK_STATUS(rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)x, (const bf16_t *)w));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)x, (const float *)w));
} else if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)x, (const fp16_t *)w));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
@@ -34,7 +34,7 @@ public:
}
if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
// For half-precision activations (FP16/BF16), weights may now be FP16, BF16, or FP32
- if (wtype != atype && wtype != INFINI_DTYPE_F32) {
+ if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
@@ -77,10 +77,14 @@ infiniStatus_t launchKernel(
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__nv_bfloat16, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__nv_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__nv_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
#include "rope_ascend.h"
#include "../../../devices/ascend/common_ascend.h"
namespace op::rope::ascend {
Descriptor::~Descriptor()
= default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(result);
size_t workspace_size = 0;
*desc_ptr = new Descriptor(result.take(), workspace_size, nullptr, handle_ascend->device, handle_ascend->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);
auto data_type = _info.data_type;
auto pos_type = _info.pos_type;
auto seq_len = _info.seqlen;
auto nhead = _info.nhead;
auto dhead = _info.dhead;
auto y_stride_seqlen = _info.y_stride_seqlen;
auto y_stride_nhead = _info.y_stride_nhead;
auto x_stride_seqlen = _info.x_stride_seqlen;
auto x_stride_nhead = _info.x_stride_nhead;
return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__
#include "../rope.h"
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t data_type,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream);
DESCRIPTOR(ascend)
#endif // __ACLNN_ROPE_H__
#include "../../../devices/ascend/ascend_kernel_common.h"
using namespace AscendC;
template <typename T, typename U>
class RoPEKernel {
public:
__aicore__ inline RoPEKernel() {}
// Init op
// pos position vector
// x input tensor
// y output tensor
// tensor shape [nt, nh, dh]
// make block_num = nh, tile_len = dh
__aicore__ inline void init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh);
__aicore__ inline void process(size_t seq_len);
private:
// Copy a tile into UB
__aicore__ inline void copyIn(size_t i);
__aicore__ inline void compute(size_t i);
__aicore__ inline void copyOut(size_t i);
private:
TPipe pipe;
TQue<QuePosition::VECIN, BUFFER_NUM> _in_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _sin_que;
TQue<QuePosition::VECIN, BUFFER_NUM> _cos_que;
TQue<QuePosition::VECOUT, BUFFER_NUM> _out_que;
TBuf<TPosition::VECCALC> _tmp_odd_buf;
TBuf<TPosition::VECCALC> _tmp_even_buf;
TBuf<TPosition::VECCALC> _tmp_odd_buf1;
TBuf<TPosition::VECCALC> _tmp_odd_buf2;
TBuf<TPosition::VECCALC> _tmp_even_buf1;
TBuf<TPosition::VECCALC> _tmp_even_buf2;
GlobalTensor<T> _x_gm, _y_gm;
GlobalTensor<U> _p_gm;
GlobalTensor<T> _sin_gm;
GlobalTensor<T> _cos_gm;
size_t _block_idx;
size_t _tile_len;
size_t _copy_len;
size_t _half_copy_len;
// stridey[_st_ynt, _st_ynh, 1]
ptrdiff_t _st_ynt;
ptrdiff_t _st_ynh;
// stridex[_st_xnt, _st_xnh, 1]
ptrdiff_t _st_xnt;
ptrdiff_t _st_xnh;
};
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
GM_ADDR x,
GM_ADDR pos,
GM_ADDR sin,
GM_ADDR cos,
size_t dh,
ptrdiff_t st_ynt,
ptrdiff_t st_ynh,
ptrdiff_t st_xnt,
ptrdiff_t st_xnh) {
this->_tile_len = dh;
this->_st_ynt = st_ynt;
this->_st_ynh = st_ynh;
this->_st_xnt = st_xnt;
this->_st_xnh = st_xnh;
_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
// sin/cos table rows hold dh / 2 elements each
_half_copy_len = alignTileLen<T>(dh / 2, BYTE_ALIGN);
_block_idx = GetBlockIdx();
// Init global buffer
_x_gm.SetGlobalBuffer((__gm__ T *)x);
_p_gm.SetGlobalBuffer((__gm__ U *)pos);
_sin_gm.SetGlobalBuffer((__gm__ T *)sin);
_cos_gm.SetGlobalBuffer((__gm__ T *)cos);
_y_gm.SetGlobalBuffer((__gm__ T *)y);
// Init Queue buffer
pipe.InitBuffer(_in_que, BUFFER_NUM, _copy_len * sizeof(T));
pipe.InitBuffer(_out_que, BUFFER_NUM, _tile_len * sizeof(T));
pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_copy_len * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_odd_buf2, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf1, _tile_len / 2 * sizeof(T));
pipe.InitBuffer(_tmp_even_buf2, _tile_len / 2 * sizeof(T));
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
LocalTensor<T> input_ub = _in_que.AllocTensor<T>();
LocalTensor<T> sin_ub = _sin_que.AllocTensor<T>();
LocalTensor<T> cos_ub = _cos_que.AllocTensor<T>();
// Get idx of current tile in total input
auto idx = i * _st_xnt + _block_idx * _st_xnh;
// Copy the current input tile into UB
DataCopy(input_ub, _x_gm[idx], _copy_len);
// Copy sin cos tile
auto pos_idx = _p_gm(i);
DataCopy(sin_ub, _sin_gm[pos_idx * _tile_len / 2], _half_copy_len);
DataCopy(cos_ub, _cos_gm[pos_idx * _tile_len / 2], _half_copy_len);
// Push in operands
_in_que.EnQue(input_ub);
_sin_que.EnQue(sin_ub);
_cos_que.EnQue(cos_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
LocalTensor<T> input_ub = _in_que.DeQue<T>();
LocalTensor<T> sin_ub = _sin_que.DeQue<T>();
LocalTensor<T> cos_ub = _cos_que.DeQue<T>();
LocalTensor<T> output_ub = _out_que.AllocTensor<T>();
LocalTensor<T> tmp_odd = _tmp_odd_buf.Get<T>();
LocalTensor<T> tmp_even = _tmp_even_buf.Get<T>();
LocalTensor<T> tmp_odd1 = _tmp_odd_buf1.Get<T>();
LocalTensor<T> tmp_odd2 = _tmp_odd_buf2.Get<T>();
LocalTensor<T> tmp_even1 = _tmp_even_buf1.Get<T>();
LocalTensor<T> tmp_even2 = _tmp_even_buf2.Get<T>();
// separate odd and even bit elements
uint64_t rsvdCnt = 0;
GatherMaskParams gMaskParams = {
1,
static_cast<uint16_t>((_tile_len * sizeof(T) + 255) / 256), // repeat count: 256-byte blocks covering the tile (must stay <= 255)
8,
8,
};
GatherMask<T>(tmp_odd, input_ub, 1, false, 0, gMaskParams, rsvdCnt);
GatherMask<T>(tmp_even, input_ub, 2, false, 0, gMaskParams, rsvdCnt);
PipeBarrier<PIPE_V>();
// compute odd bit elements
// y_odd = x_odd * cos - x_even * sin
Mul<T>(tmp_odd1, tmp_odd, cos_ub, _tile_len / 2);
Mul<T>(tmp_odd2, tmp_even, sin_ub, _tile_len / 2);
PipeBarrier<PIPE_V>();
Sub<T>(tmp_odd1, tmp_odd1, tmp_odd2, _tile_len / 2);
// compute even bit elements
// y_even = x_odd * sin + x_even * cos
Mul<T>(tmp_even1, tmp_odd, sin_ub, _tile_len / 2);
Mul<T>(tmp_even2, tmp_even, cos_ub, _tile_len / 2);
PipeBarrier<PIPE_V>();
Add<T>(tmp_even1, tmp_even1, tmp_even2, _tile_len / 2);
// combine odd and even bit elements
for (uint32_t j = 0; j < _tile_len / 2; j += 1) {
output_ub(j * 2) = tmp_odd1(j);
output_ub(j * 2 + 1) = tmp_even1(j);
}
_out_que.EnQue<T>(output_ub);
_in_que.FreeTensor(input_ub);
_sin_que.FreeTensor(sin_ub);
_cos_que.FreeTensor(cos_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
LocalTensor<T> output_ub = _out_que.DeQue<T>();
auto idy = i * _st_ynt + _block_idx * _st_ynh;
DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
DataCopyPad(_y_gm[idy], output_ub, params);
_out_que.FreeTensor(output_ub);
}
template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
for (size_t i = 0; i < seq_len; ++i) {
copyIn(i);
compute(i);
copyOut(i);
}
}
#define ROPE_KERNEL_INIT_ARGS y, x, pos, sin, cos, dhead, \
y_stride_seqlen, y_stride_nhead, \
x_stride_seqlen, x_stride_nhead
#define CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T) \
case POS_TYPE_ENUM: { \
RoPEKernel<TYPE, POS_T> op; \
op.init(ROPE_KERNEL_INIT_ARGS); \
op.process(seq_len); \
break; \
}
#define ROPE_KERNEL(TYPE, POSTYPE) \
switch (POSTYPE) { \
CASE_POSTYPE(INFINI_DTYPE_I8, TYPE, int8_t) \
CASE_POSTYPE(INFINI_DTYPE_I16, TYPE, int16_t) \
CASE_POSTYPE(INFINI_DTYPE_I32, TYPE, int32_t) \
CASE_POSTYPE(INFINI_DTYPE_I64, TYPE, int64_t) \
CASE_POSTYPE(INFINI_DTYPE_U8, TYPE, uint8_t) \
CASE_POSTYPE(INFINI_DTYPE_U16, TYPE, uint16_t) \
CASE_POSTYPE(INFINI_DTYPE_U32, TYPE, uint32_t) \
CASE_POSTYPE(INFINI_DTYPE_U64, TYPE, uint64_t) \
default: \
break; \
}
#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE) \
__global__ __aicore__ void KERNEL_NAME(GM_ADDR y, \
GM_ADDR x, \
GM_ADDR pos, \
GM_ADDR sin, \
GM_ADDR cos, \
size_t seq_len, \
size_t dhead, \
ptrdiff_t y_stride_seqlen, \
ptrdiff_t y_stride_nhead, \
ptrdiff_t x_stride_seqlen, \
ptrdiff_t x_stride_nhead, \
int32_t pos_type) { \
ROPE_KERNEL(TYPE, pos_type) \
}
DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)
#undef DEFINE_ROPE_KERNEL
#undef ROPE_KERNEL
#undef CASE_POSTYPE
#undef ROPE_KERNEL_INIT_ARGS
extern "C" infiniStatus_t rope_kernel_launch(
void *y,
void *x,
void *pos,
void *sin,
void *cos,
size_t seq_len,
size_t nhead,
size_t dhead,
infiniDtype_t dtype,
infiniDtype_t pos_type,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
void *stream) {
#define LAUNCH_ROPE_KERNEL(DTYPE_ENUM, KERNEL_NAME) \
case DTYPE_ENUM: \
KERNEL_NAME<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, \
seq_len, \
dhead, \
y_stride_seqlen, \
y_stride_nhead, \
x_stride_seqlen, \
x_stride_nhead, \
pos_type); \
return INFINI_STATUS_SUCCESS;
switch (dtype) {
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F16, rope_kernel_half)
LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F32, rope_kernel_float)
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
#ifndef __INFINIOP_ROPE_BANG_H__
#define __INFINIOP_ROPE_BANG_H__
#include "../rope.h"
DESCRIPTOR(bang)
#endif // __INFINIOP_ROPE_BANG_H__
#include "../../../devices/bang/common_bang.h"
#include "rope_bang.h"
#include "rope_bang_kernel.mlu"
namespace op::rope::bang {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(info);
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
0,
nullptr,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
cnrtQueue_t queue) {
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
auto table_dim = uint32_t(info.table_dim);
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
// Configure kernel launch parameters
k_dim.x = 4;
k_dim.y = 1;
k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1;
// Launch kernel
ropeKernel<<<k_dim, k_type, queue>>>(
y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead);
cnrtQueueSync(queue);
return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(cnrtQueue_t)stream)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(bfloat16_t);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::bang
#include "../../../devices/bang/common_bang.h"
__nram__ char nram_buffer[NRAM_MAX_SIZE];
template <typename Tdata>
__mlu_device__ void calculateRope(
Tdata *out, const Tdata *in,
const Tdata *sin_table, const Tdata *cos_table,
Tdata *sin_cache, Tdata *cos_cache,
Tdata *x1sin, Tdata *x0cos, Tdata *x0sin, Tdata *x1cos,
Tdata *input_0, Tdata *input_1, Tdata *input_cache,
int theta_index, int out_index, int in_index,
int chunk_size, int half_chunk_size, int data_segsize,
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
// Load sin/cos data
__memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
__memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Load input data
__memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Split input into even and odd positions
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
// Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos
__bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
__bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
__bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
__bang_mul(x1cos, input_1, cos_cache, half_chunk_size);
__bang_sub(input_0, x0cos, x1sin, half_chunk_size);
__bang_add(input_1, x0sin, x1cos, half_chunk_size);
// Interleave results back into output buffer
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
// Write back results
__memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
}
template <typename Tdata, typename Tindex>
__mlu_global__ void ropeKernel(
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
uint32_t seqlen,
uint32_t nhead,
uint32_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
// Calculate available NRAM space after alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));
// Key variables that determine execution path
const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
// Common stride configurations
const int data_segsize = sizeof(Tdata);
const int src_load_stride = 2 * sizeof(Tdata);
const int dst_load_stride = 1 * sizeof(Tdata);
const int src_write_stride = 1 * sizeof(Tdata);
const int dst_write_stride = 2 * sizeof(Tdata);
// Task distribution
const int batch_volume = seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
const int task_start_idx = (taskId < remaining_tasks ? taskId * base_tasks_per_core + taskId : taskId * base_tasks_per_core + remaining_tasks);
// NRAM buffer allocation with proper alignment
char *aligned_nram = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
// Setup position IDs if they fit in NRAM
Tindex *srcP = nullptr;
if (use_pos_ids_buffer) {
srcP = (Tindex *)aligned_nram;
__memcpy(srcP, pos_ids, seqlen * sizeof(Tindex), GDRAM2NRAM);
aligned_nram = (char *)(((size_t)srcP + seqlen * sizeof(Tindex) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
}
// Main processing buffers (pointers will be set per chunk)
Tdata *sin_cache = nullptr;
Tdata *cos_cache = nullptr;
Tdata *x1sin = nullptr;
Tdata *x0cos = nullptr;
Tdata *x0sin = nullptr;
Tdata *x1cos = nullptr;
Tdata *input_0 = nullptr;
Tdata *input_1 = nullptr;
Tdata *input_cache = nullptr;
// Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
// Calculate output and input indices
int seq_idx = i / nhead;
int head_idx = i % nhead;
// Output indices (y)
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
// Input indices (x)
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;
// Process in chunks that fit in NRAM
int processed = 0;
while (processed < table_dim) {
// Calculate current chunk size
int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
int current_chunk_size = 2 * current_half_chunk;
int theta_offset = rot_offset + processed;
int dst_offset = out_offset + processed * 2;
int src_offset = in_offset + processed * 2;
// Set up NRAM buffers for this chunk
char *chunk_base = aligned_nram;
sin_cache = (Tdata *)chunk_base;
cos_cache = sin_cache + current_half_chunk;
x1sin = cos_cache + current_half_chunk;
x0cos = x1sin + current_half_chunk;
x0sin = x0cos + current_half_chunk;
x1cos = x0sin + current_half_chunk;
input_0 = x1cos + current_half_chunk;
input_1 = input_0 + current_half_chunk;
input_cache = input_1 + current_half_chunk;
calculateRope<Tdata>(
y, x, sin_table, cos_table,
sin_cache, cos_cache, x1sin, x0cos, x0sin, x1cos,
input_0, input_1, input_cache,
theta_offset, dst_offset, src_offset,
current_chunk_size, current_half_chunk,
data_segsize,
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
processed += current_half_chunk;
}
}
}
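A quick worked example of the chunking arithmetic at the top of ropeKernel (NRAM_MAX_SIZE and ALIGN_SIZE are platform constants; the sizes below are assumptions for illustration only):

// Hypothetical: NRAM_MAX_SIZE = 512 KiB, ALIGN_SIZE = 128, Tdata = fp16.
// nram_usable        = 512 * 1024 - 128 * 9      = 523136 bytes
// max_chunk_elements = 523136 / (9 * 2)          = 29063 elements per buffer
// half_chunk_size    = min(29063 / 2, table_dim) -> whole table in one chunk,
// so the chunk loop above runs once unless table_dim exceeds ~14.5k.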
#include "rope_v2_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
namespace op::rope_v2::cpu {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto info = RoPEv2Info::createRoPEv2Info(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(info);
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
0,
nullptr,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPEv2(const RoPEv2Info &info,
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table) {
#pragma omp parallel for
for (ptrdiff_t h = 0; h < ptrdiff_t(info.nhead); h++) {
for (size_t tok = 0; tok < info.seqlen; tok++) {
size_t x_offset = tok * info.x_stride_seqlen + h * info.x_stride_nhead;
size_t y_offset = tok * info.y_stride_seqlen + h * info.y_stride_nhead;
size_t pos_id = size_t(pos_ids[tok]);
size_t table_offset = pos_id * info.table_dim;
size_t half_dim = info.table_dim; // table_dim == head_dim / 2; pairs are (i, i + half_dim)
for (size_t i = 0; i < info.table_dim; i++) {
// Pair elements from first half and second half
size_t pos0 = i;
size_t pos1 = i + half_dim;
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
float x0 = utils::cast<float>(x[x_offset + pos0]),
x1 = utils::cast<float>(x[x_offset + pos1]),
sin__ = utils::cast<float>(sin_table[table_offset + i]),
cos__ = utils::cast<float>(cos_table[table_offset + i]);
y[y_offset + pos0] = utils::cast<Tdata>(x0 * cos__ - x1 * sin__);
y[y_offset + pos1] = utils::cast<Tdata>(x0 * sin__ + x1 * cos__);
} else {
Tdata x0 = x[x_offset + pos0],
x1 = x[x_offset + pos1],
sin__ = sin_table[table_offset + i],
cos__ = cos_table[table_offset + i];
y[y_offset + pos0] = x0 * cos__ - x1 * sin__;
y[y_offset + pos1] = x0 * sin__ + x1 * cos__;
}
}
}
}
return INFINI_STATUS_SUCCESS;
}
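Both RoPE variants in this commit apply the same two-dimensional rotation per element pair; only the pairing differs. In LaTeX, with $\theta_{p,i}$ the table angle for position $p$ and pair index $i$:

$$y_a = x_a \cos\theta_{p,i} - x_b \sin\theta_{p,i}, \qquad y_b = x_a \sin\theta_{p,i} + x_b \cos\theta_{p,i}$$

The Ascend and Bang kernels above pair interleaved elements, $(a, b) = (2i, 2i + 1)$, while this v2 path pairs split halves, $(a, b) = (i, i + \mathrm{table\_dim})$.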
#define CALCULATE_ROPE_V2(TDATA, TINDEX) \
calculateRoPEv2(_info, (TDATA *)y, (const TDATA *)x, (const TINDEX *)pos_ids, (const TDATA *)sin_table, (const TDATA *)cos_table)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE_V2(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE_V2(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE_V2(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE_V2(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE_V2(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE_V2(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE_V2(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE_V2(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(fp16_t);
case INFINI_DTYPE_BF16:
ROPE_TYPE(bf16_t);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
ROPE_TYPE(double);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE_V2
} // namespace op::rope_v2::cpu
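A tiny standalone check of the rotation used above (a sketch in plain floats, no infiniop types; the rotation must preserve each pair's norm):

#include <cassert>
#include <cmath>

int main() {
    float x0 = 0.6f, x1 = 0.8f, theta = 0.3f;
    float c = std::cos(theta), s = std::sin(theta);
    float y0 = x0 * c - x1 * s; // y[pos0], pos0 = i
    float y1 = x0 * s + x1 * c; // y[pos1], pos1 = i + half_dim
    assert(std::fabs((y0 * y0 + y1 * y1) - (x0 * x0 + x1 * x1)) < 1e-6f);
    return 0;
}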