Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
47895fae
Unverified
Commit
47895fae
authored
Aug 21, 2025
by
xgqdut2016
Committed by
GitHub
Aug 21, 2025
Browse files
issue/163: 昆仑芯平台rope算子重构
parent
e20c0000
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
199 additions
and
0 deletions
+199
-0
src/infiniop/ops/rope/kunlun/rope_kunlun.h
src/infiniop/ops/rope/kunlun/rope_kunlun.h
+8
-0
src/infiniop/ops/rope/kunlun/rope_kunlun.xpu
src/infiniop/ops/rope/kunlun/rope_kunlun.xpu
+176
-0
src/infiniop/ops/rope/operator.cc
src/infiniop/ops/rope/operator.cc
+15
-0
No files found.
src/infiniop/ops/rope/kunlun/rope_kunlun.h
0 → 100644
View file @
47895fae
#ifndef __INFINIOP_ROPE_KUNLUN_H__
#define __INFINIOP_ROPE_KUNLUN_H__
#include "../rope.h"
// Instantiates the RoPE descriptor declaration for the `kunlun` backend.
// NOTE(review): DESCRIPTOR is presumably defined in ../rope.h and expands to
// the op::rope::kunlun::Descriptor class whose members are implemented in
// rope_kunlun.xpu — confirm against rope.h.
DESCRIPTOR
(
kunlun
)
#endif // __INFINIOP_ROPE_KUNLUN_H__
src/infiniop/ops/rope/kunlun/rope_kunlun.xpu
0 → 100644
View file @
47895fae
#ifndef __ROPE_KUNLUN_KERNEL_XPU__
#define __ROPE_KUNLUN_KERNEL_XPU__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "rope_kunlun.h"
#include <memory>
/**
 * Device kernel that applies a rotary position embedding to `source` and
 * writes the result to `destination`.
 *
 * The work is seqlen * nhead rows of `dhead` elements each. Every row is
 * processed as adjacent (even, odd) element pairs: pair k is rotated using
 * sin_table[pos * dhead / 2 + k] and cos_table[pos * dhead / 2 + k], where
 * `pos` is the position id read from pos_ids for the row's sequence index.
 *
 * @param destination      Output buffer, addressed via the y_* strides.
 * @param source           Input buffer, addressed via the x_* strides.
 * @param pos_ids          One position id per sequence index.
 * @param sin_table        Sine table, dhead / 2 entries per position.
 * @param cos_table        Cosine table, dhead / 2 entries per position.
 * @param seqlen           Number of sequence positions.
 * @param nhead            Number of heads per position.
 * @param dhead            Elements per head (processed two at a time).
 * @param x_stride_seqlen  Element stride of `source` along the sequence axis.
 * @param x_stride_nhead   Element stride of `source` along the head axis.
 * @param y_stride_seqlen  Element stride of `destination` along the sequence axis.
 * @param y_stride_nhead   Element stride of `destination` along the head axis.
 * @param stream           Not used inside the kernel body.
 */
template <typename T, typename Tindex>
__global__ void RoPEKernel(T *destination, const T *source,
                           const Tindex *pos_ids, const T *sin_table, const T *cos_table,
                           uint32_t seqlen, uint32_t nhead, uint32_t dhead,
                           int32_t x_stride_seqlen, int32_t x_stride_nhead,
                           int32_t y_stride_seqlen, int32_t y_stride_nhead,
                           XPUStream stream) {
    // ndim = 3: one work item per (sequence index, head) row.
    uint32_t total_rows = seqlen * nhead;

    int core = core_id();
    int cores_per_cluster = core_num();
    if (core >= cores_per_cluster) {
        return;
    }

    // Flatten (cluster, core) into a single worker index.
    int worker = cores_per_cluster * cluster_id() + core;
    int num_workers = cores_per_cluster * cluster_num();

    // Split the rows as evenly as possible: the first `leftover` workers
    // take one extra row each.
    int leftover = total_rows % num_workers;
    int base_rows = total_rows / num_workers;
    int my_rows;
    int first_row;
    if (worker < leftover) {
        my_rows = base_rows + 1;
        first_row = worker * my_rows;
    } else {
        my_rows = base_rows;
        first_row = leftover * (base_rows + 1) + (worker - leftover) * base_rows;
    }
    int row_end = first_row + my_rows;

    // Staging buffers declared __local__; a row is processed in chunks of at
    // most buf_size elements.
    constexpr int buf_size = 256;
    __local__ T in_buf[buf_size];
    __local__ T out_buf[buf_size];
    __local__ T sin_buf[buf_size];
    __local__ T cos_buf[buf_size];
    __local__ Tindex pos_buf[1];

    int tail_len = dhead % buf_size;
    int full_chunks = dhead / buf_size;
    int num_chunks = full_chunks + (tail_len > 0 ? 1 : 0);

    for (int row = first_row; row < row_end; ++row) {
        // Decompose the flat row index: the head index varies fastest.
        uint32_t head = row % nhead;
        uint32_t token = (row / nhead) % seqlen;

        int dst_base = head * y_stride_nhead + token * y_stride_seqlen;
        int src_base = head * x_stride_nhead + token * x_stride_seqlen;

        // Fetch this row's position id and turn it into a table offset.
        GM2LM(pos_ids + token, pos_buf, 1 * sizeof(Tindex));
        int table_base = static_cast<int>(pos_buf[0]) * dhead / 2;

        for (int chunk = 0; chunk < num_chunks; ++chunk) {
            int chunk_len = (chunk < full_chunks ? buf_size : tail_len);
            int pairs = chunk_len / 2;
            int offset = chunk * buf_size;
            int dst_off = dst_base + offset;
            int src_off = src_base + offset;
            // Each pair of elements consumes one sin/cos table entry.
            int table_off = table_base + offset / 2;

            GM2LM(source + src_off, in_buf, chunk_len * sizeof(T));
            GM2LM(sin_table + table_off, sin_buf, pairs * sizeof(T));
            GM2LM(cos_table + table_off, cos_buf, pairs * sizeof(T));

            if constexpr (xpu_std::is_same<T, float>::value || xpu_std::is_same<T, half>::value) {
                // float / half: rotate directly in T.
                for (int p = 0; p < pairs; ++p) {
                    out_buf[2 * p] = in_buf[2 * p] * cos_buf[p] - in_buf[2 * p + 1] * sin_buf[p];
                    out_buf[2 * p + 1] = in_buf[2 * p] * sin_buf[p] + in_buf[2 * p + 1] * cos_buf[p];
                }
            } else if constexpr (xpu_std::is_same<T, bfloat16_t>::value) {
                // bfloat16: convert to float, rotate, convert back.
                for (int p = 0; p < pairs; ++p) {
                    float x_even = __bfloat162float(in_buf[2 * p]);
                    float x_odd = __bfloat162float(in_buf[2 * p + 1]);
                    float sin_val = __bfloat162float(sin_buf[p]);
                    float cos_val = __bfloat162float(cos_buf[p]);
                    out_buf[2 * p] = __float2bfloat16(x_even * cos_val - x_odd * sin_val);
                    out_buf[2 * p + 1] = __float2bfloat16(x_even * sin_val + x_odd * cos_val);
                }
            }

            // NOTE(review): mfence() presumably orders the writes to out_buf
            // before the copy-out below — confirm against the XPU docs.
            mfence();
            LM2GM(out_buf, destination + dst_off, chunk_len * sizeof(T));
        }
    }
}
/**
 * Host-side launcher for RoPEKernel<T, Tindex>.
 *
 * Casts the type-erased buffers to their element types and enqueues the
 * kernel on `stream`. The read-only inputs are cast with static_cast to
 * const-qualified pointers so that constness is preserved all the way to the
 * kernel, whose parameters are declared const.
 *
 * @param destination  Output buffer of T.
 * @param source       Input buffer of T (read-only).
 * @param pos_ids      Position ids of type Tindex (read-only).
 * @param sin_table    Sine table of T (read-only).
 * @param cos_table    Cosine table of T (read-only).
 * @param seqlen, nhead, dhead  Problem dimensions forwarded to the kernel.
 * @param x_stride_seqlen, x_stride_nhead  Strides of `source`, forwarded as is.
 * @param y_stride_seqlen, y_stride_nhead  Strides of `destination`, forwarded as is.
 * @param stream       Stream the kernel is launched on; also forwarded as the
 *                     kernel's last argument.
 */
template <typename T, typename Tindex>
void RoPE(void *destination, const void *source,
          const void *pos_ids, const void *sin_table, const void *cos_table,
          uint32_t seqlen, uint32_t nhead, uint32_t dhead,
          int32_t x_stride_seqlen, int32_t x_stride_nhead,
          int32_t y_stride_seqlen, int32_t y_stride_nhead,
          XPUStream stream) {
    // NOTE(review): <<<8, 64, stream>>> is presumably (clusters, cores per
    // cluster, stream) — confirm against the XPU launch documentation.
    RoPEKernel<T, Tindex><<<8, 64, stream>>>(
        static_cast<T *>(destination),
        static_cast<const T *>(source),
        static_cast<const Tindex *>(pos_ids),
        static_cast<const T *>(sin_table),
        static_cast<const T *>(cos_table),
        seqlen, nhead, dhead,
        x_stride_seqlen, x_stride_nhead,
        y_stride_seqlen, y_stride_nhead, stream);
}
// Launches RoPE<T, Tindex> using the locals of the enclosing function, which
// must be named y, x, pos_ids, sin_table, cos_table, seqlen, nhead, dhead,
// x_stride_seqlen, x_stride_nhead, y_stride_seqlen, y_stride_nhead and stream.
// The expansion deliberately has no trailing semicolon: the call site supplies
// it, so the macro behaves as one statement (safe in an unbraced if/else).
#define LAUNCH_KERNEL(T, Tindex)                              \
    RoPE<T, Tindex>(y, x, pos_ids, sin_table, cos_table,      \
                    seqlen, nhead, dhead,                     \
                    x_stride_seqlen, x_stride_nhead,          \
                    y_stride_seqlen, y_stride_nhead,          \
                    reinterpret_cast<kunlunStream_t>(stream))
namespace op::rope::kunlun {
// Backend-private state attached to the Kunlun RoPE descriptor.
struct Descriptor::Opaque {
// Shared reference to the Kunlun handle's internal state; Descriptor::create
// fills it from device::kunlun::Handle::internal().
std::shared_ptr<device::kunlun::Handle::Internal> internal;
};
// Frees the backend-private Opaque state held in _opaque
// (presumably the object allocated with `new` in Descriptor::create — confirm
// against the Descriptor constructor in rope.h).
Descriptor::~Descriptor() {
delete _opaque;
}
// Builds a Kunlun RoPE descriptor from the tensor descriptors.
//
// The tensor descriptors are handed to RoPEInfo::createRoPEInfo; if that
// fails, CHECK_RESULT returns early. On success a new Descriptor is stored in
// *desc_ptr, holding the RoPE info, a workspace size of 0, a freshly allocated
// Opaque carrying the handle's internal state, and the handle's device and
// device id.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(info);

    // Keep the Kunlun handle internals reachable from the descriptor.
    auto *kunlun_handle = static_cast<device::kunlun::Handle *>(handle);
    auto *opaque = new Descriptor::Opaque{kunlun_handle->internal()};

    // The second argument is the workspace size: none is requested.
    *desc_ptr = new Descriptor(
        info.take(),
        0,
        opaque,
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}
// Launches the Kunlun RoPE kernel for the shapes and strides stored in _info.
//
// `workspace` and `workspace_size` are not used. Only int32 position ids
// combined with F32, F16 or BF16 data are dispatched; every other combination
// returns INFINI_STATUS_BAD_TENSOR_DTYPE. On a supported combination the
// kernel is launched via LAUNCH_KERNEL and INFINI_STATUS_SUCCESS is returned.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {
    // Reject unsupported position-id types up front.
    if (_info.pos_type != INFINI_DTYPE_I32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // The kernel takes 32-bit sizes and strides. These names are fixed:
    // LAUNCH_KERNEL refers to them directly.
    const auto seqlen = static_cast<uint32_t>(_info.seqlen);
    const auto nhead = static_cast<uint32_t>(_info.nhead);
    const auto dhead = static_cast<uint32_t>(_info.dhead);
    const auto x_stride_seqlen = static_cast<int32_t>(_info.x_stride_seqlen);
    const auto x_stride_nhead = static_cast<int32_t>(_info.x_stride_nhead);
    const auto y_stride_seqlen = static_cast<int32_t>(_info.y_stride_seqlen);
    const auto y_stride_nhead = static_cast<int32_t>(_info.y_stride_nhead);

    switch (_info.data_type) {
    case INFINI_DTYPE_F32:
        LAUNCH_KERNEL(float, int32_t);
        break;
    case INFINI_DTYPE_F16:
        LAUNCH_KERNEL(half, int32_t);
        break;
    case INFINI_DTYPE_BF16:
        LAUNCH_KERNEL(bfloat16_t, int32_t);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}
} // namespace op::rope::kunlun
#endif
src/infiniop/ops/rope/operator.cc
View file @
47895fae
...
...
@@ -17,6 +17,9 @@
#ifdef ENABLE_METAX_API
#include "metax/rope_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rope_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateRoPEDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -54,6 +57,9 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
#ifdef ENABLE_ASCEND_API
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
...
...
@@ -91,6 +97,9 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#ifdef ENABLE_CAMBRICON_API
GET
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
...
...
@@ -138,6 +147,9 @@ __C infiniStatus_t infiniopRoPE(
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
...
...
@@ -178,6 +190,9 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment