Unverified Commit 2a81c8bd authored by zhangyue's avatar zhangyue Committed by GitHub
Browse files

Merge pull request #467 from InfiniTensor/issue/466

issue/466: 昆仑平台rope关于NEOX算法的实现 (Kunlun platform: implementation of the NEOX algorithm variant for RoPE)
parents d0b7bf92 c15189bf
...@@ -12,7 +12,7 @@ __global__ void RoPEKernel(T *destination, const T *source, ...@@ -12,7 +12,7 @@ __global__ void RoPEKernel(T *destination, const T *source,
const Tindex *pos_ids, const T *sin_table, const T *cos_table, const Tindex *pos_ids, const T *sin_table, const T *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead, uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead, int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead, int32_t y_stride_seqlen, int32_t y_stride_nhead, bool IsGPTJ,
XPUStream stream) { XPUStream stream) {
// ndim = 3 // ndim = 3
uint32_t other_size = seqlen * nhead; uint32_t other_size = seqlen * nhead;
...@@ -41,6 +41,11 @@ __global__ void RoPEKernel(T *destination, const T *source, ...@@ -41,6 +41,11 @@ __global__ void RoPEKernel(T *destination, const T *source,
int remain_dhead = dhead % buf_size; int remain_dhead = dhead % buf_size;
int repeat = (dhead - remain_dhead) / buf_size; int repeat = (dhead - remain_dhead) / buf_size;
int table_dim = dhead / 2;
constexpr int buf_table = buf_size / 2;
int remain_table = table_dim % buf_table;
int repeat_table = (table_dim - remain_table) / buf_table;
for (int i = ind_start; i < ind_start + step; i++) { for (int i = ind_start; i < ind_start + step; i++) {
int ind_i = i; int ind_i = i;
int ind_d = 0; int ind_d = 0;
...@@ -51,7 +56,8 @@ __global__ void RoPEKernel(T *destination, const T *source, ...@@ -51,7 +56,8 @@ __global__ void RoPEKernel(T *destination, const T *source,
ind_d += (ind_i % seqlen) * y_stride_seqlen; ind_d += (ind_i % seqlen) * y_stride_seqlen;
ind_s += (ind_i % seqlen) * x_stride_seqlen; ind_s += (ind_i % seqlen) * x_stride_seqlen;
GM2LM(pos_ids + (ind_i % seqlen), pos_local, 1 * sizeof(Tindex)); GM2LM(pos_ids + (ind_i % seqlen), pos_local, 1 * sizeof(Tindex));
int index = static_cast<int>(pos_local[0]) * dhead / 2; int index = static_cast<int>(pos_local[0]) * table_dim;
if (IsGPTJ){
for (int r = 0; r < repeat + (remain_dhead > 0 ? 1 : 0); r++) { for (int r = 0; r < repeat + (remain_dhead > 0 ? 1 : 0); r++) {
int read_len = (r < repeat ? buf_size : remain_dhead); int read_len = (r < repeat ? buf_size : remain_dhead);
int dk = read_len / 2; int dk = read_len / 2;
...@@ -80,6 +86,40 @@ __global__ void RoPEKernel(T *destination, const T *source, ...@@ -80,6 +86,40 @@ __global__ void RoPEKernel(T *destination, const T *source,
LM2GM(y_local, destination + start_d, read_len * sizeof(T)); LM2GM(y_local, destination + start_d, read_len * sizeof(T));
} }
} }
else{
for (int r = 0; r < repeat_table + (remain_table > 0 ? 1 : 0); r++) {
int read_len = (r < repeat_table ? buf_table : remain_table);
int start_d_0 = ind_d + r * buf_table;
int start_s_0 = ind_s + r * buf_table;
int start_d_1 = ind_d + r * buf_table + table_dim;
int start_s_1 = ind_s + r * buf_table + table_dim;
int sin_cos_index = index + r * buf_table;
GM2LM(source + start_s_0, x_local, read_len * sizeof(T));
GM2LM(source + start_s_1, x_local + buf_table, read_len * sizeof(T));
GM2LM(sin_table + sin_cos_index, sin_local, read_len * sizeof(T));
GM2LM(cos_table + sin_cos_index, cos_local, read_len * sizeof(T));
if constexpr (xpu_std::is_same<T, float>::value || xpu_std::is_same<T, half>::value) {
for (int k = 0; k < read_len; k++) {
y_local[k] = x_local[k] * cos_local[k] - x_local[k + buf_table] * sin_local[k];
y_local[k + buf_table] = x_local[k] * sin_local[k] + x_local[k + buf_table] * cos_local[k];
}
} else if (xpu_std::is_same<T, bfloat16_t>::value) {
for (int k = 0; k < read_len; k++) {
float x_0 = __bfloat162float(x_local[k]);
float x_1 = __bfloat162float(x_local[k + buf_table]);
float sin_f = __bfloat162float(sin_local[k]);
float cos_f = __bfloat162float(cos_local[k]);
y_local[k] = __float2bfloat16(x_0 * cos_f - x_1 * sin_f);
y_local[k + buf_table] = __float2bfloat16(x_0 * sin_f + x_1 * cos_f);
}
}
mfence();
LM2GM(y_local, destination + start_d_0, read_len * sizeof(T));
LM2GM(y_local + buf_table, destination + start_d_1, read_len * sizeof(T));
}
}
}
} }
template <typename T, typename Tindex> template <typename T, typename Tindex>
...@@ -87,19 +127,19 @@ void RoPE(void *destination, const void *source, ...@@ -87,19 +127,19 @@ void RoPE(void *destination, const void *source,
const void *pos_ids, const void *sin_table, const void *cos_table, const void *pos_ids, const void *sin_table, const void *cos_table,
uint32_t seqlen, uint32_t nhead, uint32_t dhead, uint32_t seqlen, uint32_t nhead, uint32_t dhead,
int32_t x_stride_seqlen, int32_t x_stride_nhead, int32_t x_stride_seqlen, int32_t x_stride_nhead,
int32_t y_stride_seqlen, int32_t y_stride_nhead, int32_t y_stride_seqlen, int32_t y_stride_nhead, bool IsGPTJ,
XPUStream stream) { XPUStream stream) {
RoPEKernel<T, Tindex><<<8, 64, stream>>>((T *)destination, (T *)source, RoPEKernel<T, Tindex><<<8, 64, stream>>>((T *)destination, (T *)source,
(Tindex *)pos_ids, (T *)sin_table, (T *)cos_table, (Tindex *)pos_ids, (T *)sin_table, (T *)cos_table,
seqlen, nhead, dhead, seqlen, nhead, dhead,
x_stride_seqlen, x_stride_nhead, x_stride_seqlen, x_stride_nhead,
y_stride_seqlen, y_stride_nhead, stream); y_stride_seqlen, y_stride_nhead, IsGPTJ, stream);
} }
#define LAUNCH_KERNEL(T, Tindex) \ #define LAUNCH_KERNEL(T, Tindex) \
RoPE<T, Tindex>(y, x, pos_ids, sin_table, cos_table, \ RoPE<T, Tindex>(y, x, pos_ids, sin_table, cos_table, \
seqlen, nhead, dhead, \ seqlen, nhead, dhead, \
x_stride_seqlen, x_stride_nhead, \ x_stride_seqlen, x_stride_nhead, \
y_stride_seqlen, y_stride_nhead, reinterpret_cast<kunlunStream_t>(stream)); y_stride_seqlen, y_stride_nhead, IsGPTJ, reinterpret_cast<kunlunStream_t>(stream));
namespace op::rope::kunlun { namespace op::rope::kunlun {
...@@ -124,10 +164,6 @@ infiniStatus_t Descriptor::create( ...@@ -124,10 +164,6 @@ infiniStatus_t Descriptor::create(
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo); auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(result); CHECK_RESULT(result);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
// Create descriptor // Create descriptor
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
result.take(), result.take(),
...@@ -155,6 +191,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -155,6 +191,7 @@ infiniStatus_t Descriptor::calculate(
int32_t x_stride_nhead = (int32_t)_info.x_stride_nhead; int32_t x_stride_nhead = (int32_t)_info.x_stride_nhead;
int32_t y_stride_seqlen = (int32_t)_info.y_stride_seqlen; int32_t y_stride_seqlen = (int32_t)_info.y_stride_seqlen;
int32_t y_stride_nhead = (int32_t)_info.y_stride_nhead; int32_t y_stride_nhead = (int32_t)_info.y_stride_nhead;
bool IsGPTJ = _info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J;
if (_info.pos_type == INFINI_DTYPE_I32) { if (_info.pos_type == INFINI_DTYPE_I32) {
switch (_info.data_type) { switch (_info.data_type) {
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment