Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c15189bf
Commit
c15189bf
authored
Sep 18, 2025
by
xgqdut2016
Browse files
issue/466: success kunlun rope NEOX
parent
6680a8c8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
32 deletions
+69
-32
src/infiniop/ops/rope/kunlun/rope_kunlun.xpu
src/infiniop/ops/rope/kunlun/rope_kunlun.xpu
+69
-32
No files found.
src/infiniop/ops/rope/kunlun/rope_kunlun.xpu
View file @
c15189bf
...
...
@@ -12,7 +12,7 @@ __global__ void RoPEKernel(T *destination, const T *source,
const Tindex *pos_ids, const T *sin_table, const T *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
bool IsGPTJ,
XPUStream stream) {
// ndim = 3
uint32_t other_size = seqlen * nhead;
...
...
@@ -41,6 +41,11 @@ __global__ void RoPEKernel(T *destination, const T *source,
int remain_dhead = dhead % buf_size;
int repeat = (dhead - remain_dhead) / buf_size;
int table_dim = dhead / 2;
constexpr int buf_table = buf_size / 2;
int remain_table = table_dim % buf_table;
int repeat_table = (table_dim - remain_table) / buf_table;
for (int i = ind_start; i < ind_start + step; i++) {
int ind_i = i;
int ind_d = 0;
...
...
@@ -51,7 +56,8 @@ __global__ void RoPEKernel(T *destination, const T *source,
ind_d += (ind_i % seqlen) * y_stride_seqlen;
ind_s += (ind_i % seqlen) * x_stride_seqlen;
GM2LM(pos_ids + (ind_i % seqlen), pos_local, 1 * sizeof(Tindex));
int index = static_cast<int>(pos_local[0]) * dhead / 2;
int index = static_cast<int>(pos_local[0]) * table_dim;
if (IsGPTJ){
for (int r = 0; r < repeat + (remain_dhead > 0 ? 1 : 0); r++) {
int read_len = (r < repeat ? buf_size : remain_dhead);
int dk = read_len / 2;
...
...
@@ -80,6 +86,40 @@ __global__ void RoPEKernel(T *destination, const T *source,
LM2GM(y_local, destination + start_d, read_len * sizeof(T));
}
}
else{
for (int r = 0; r < repeat_table + (remain_table > 0 ? 1 : 0); r++) {
int read_len = (r < repeat_table ? buf_table : remain_table);
int start_d_0 = ind_d + r * buf_table;
int start_s_0 = ind_s + r * buf_table;
int start_d_1 = ind_d + r * buf_table + table_dim;
int start_s_1 = ind_s + r * buf_table + table_dim;
int sin_cos_index = index + r * buf_table;
GM2LM(source + start_s_0, x_local, read_len * sizeof(T));
GM2LM(source + start_s_1, x_local + buf_table, read_len * sizeof(T));
GM2LM(sin_table + sin_cos_index, sin_local, read_len * sizeof(T));
GM2LM(cos_table + sin_cos_index, cos_local, read_len * sizeof(T));
if constexpr (xpu_std::is_same<T, float>::value || xpu_std::is_same<T, half>::value) {
for (int k = 0; k < read_len; k++) {
y_local[k] = x_local[k] * cos_local[k] - x_local[k + buf_table] * sin_local[k];
y_local[k + buf_table] = x_local[k] * sin_local[k] + x_local[k + buf_table] * cos_local[k];
}
} else if (xpu_std::is_same<T, bfloat16_t>::value) {
for (int k = 0; k < read_len; k++) {
float x_0 = __bfloat162float(x_local[k]);
float x_1 = __bfloat162float(x_local[k + buf_table]);
float sin_f = __bfloat162float(sin_local[k]);
float cos_f = __bfloat162float(cos_local[k]);
y_local[k] = __float2bfloat16(x_0 * cos_f - x_1 * sin_f);
y_local[k + buf_table] = __float2bfloat16(x_0 * sin_f + x_1 * cos_f);
}
}
mfence();
LM2GM(y_local, destination + start_d_0, read_len * sizeof(T));
LM2GM(y_local + buf_table, destination + start_d_1, read_len * sizeof(T));
}
}
}
}
template <typename T, typename Tindex>
...
...
@@ -87,19 +127,19 @@ void RoPE(void *destination, const void *source,
const void *pos_ids, const void *sin_table, const void *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead,
bool IsGPTJ,
XPUStream stream) {
RoPEKernel<T, Tindex><<<8, 64, stream>>>((T *)destination, (T *)source,
(Tindex *)pos_ids, (T *)sin_table, (T *)cos_table,
seqlen, nhead, dhead,
x_stride_seqlen, x_stride_nhead,
y_stride_seqlen, y_stride_nhead, stream);
y_stride_seqlen, y_stride_nhead,
IsGPTJ,
stream);
}
// Expands to a call of RoPE<T, Tindex>, passing y, x, pos_ids, sin_table,
// cos_table, seqlen, nhead, dhead and the x/y stride values by name, plus
// `stream` reinterpret_cast to kunlunStream_t. Those identifiers are not macro
// parameters, so they must be in scope at every expansion site.
// NOTE(review): this text appears to be a diff rendered without +/- markers.
// The first line ending in `reinterpret_cast<kunlunStream_t>(stream));` looks
// like the pre-change tail of the macro; the three lines after it look like the
// post-change tail that adds IsGPTJ. As shown, the two post-change lines before
// the final cast line carry no trailing line-continuation backslash, which would
// end the macro early. Confirm against the real file.
#define LAUNCH_KERNEL(T, Tindex) \
RoPE<T, Tindex>(y, x, pos_ids, sin_table, cos_table, \
seqlen, nhead, dhead, \
x_stride_seqlen, x_stride_nhead, \
y_stride_seqlen, y_stride_nhead, reinterpret_cast<kunlunStream_t>(stream));
y_stride_seqlen, y_stride_nhead,
IsGPTJ,
reinterpret_cast<kunlunStream_t>(stream));
namespace op::rope::kunlun {
...
...
@@ -124,10 +164,6 @@ infiniStatus_t Descriptor::create(
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(result);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
// Create descriptor
*desc_ptr = new Descriptor(
result.take(),
...
...
@@ -155,6 +191,7 @@ infiniStatus_t Descriptor::calculate(
int32_t x_stride_nhead = (int32_t)_info.x_stride_nhead;
int32_t y_stride_seqlen = (int32_t)_info.y_stride_seqlen;
int32_t y_stride_nhead = (int32_t)_info.y_stride_nhead;
bool IsGPTJ = _info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J;
if (_info.pos_type == INFINI_DTYPE_I32) {
switch (_info.data_type) {
case INFINI_DTYPE_F32:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment