"...src/git@developer.sourcefind.cn:tsoc/hg-misc-tools.git" did not exist on "43f3a022cdeee8b51cf106b3ce1851a9d821fa5c"
Commit 5848b408 authored by wooway777
Browse files

issue/838 - Cambricon Batched RoPE

parent 12cde8eb
...@@ -40,8 +40,9 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -40,8 +40,9 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table, const Tdata *sin_table,
const Tdata *cos_table, const Tdata *cos_table,
cnrtQueue_t queue) { cnrtQueue_t queue) {
auto dimx = uint32_t(info.seqlen); auto batch_size = uint32_t(info.batch);
auto dimy = uint32_t(info.nhead); auto seqlen = uint32_t(info.seqlen);
auto nhead = uint32_t(info.nhead);
auto table_dim = uint32_t(info.table_dim); auto table_dim = uint32_t(info.table_dim);
cnrtDim3_t k_dim; cnrtDim3_t k_dim;
...@@ -53,12 +54,12 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -53,12 +54,12 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
k_dim.z = 1; k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1; k_type = CNRT_FUNC_TYPE_UNION1;
// Launch kernel // Launch kernel with batch dimension
ropeKernel<<<k_dim, k_type, queue>>>( ropeKernel<<<k_dim, k_type, queue>>>(
y, x, pos_ids, sin_table, cos_table, y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim, batch_size, seqlen, nhead, table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead, info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead,
info.algo); info.algo);
cnrtQueueSync(queue); cnrtQueueSync(queue);
......
...@@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel( ...@@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel(
const Tindex *pos_ids, const Tindex *pos_ids,
const Tdata *sin_table, const Tdata *sin_table,
const Tdata *cos_table, const Tdata *cos_table,
uint32_t batch_size,
uint32_t seqlen, uint32_t seqlen,
uint32_t nhead, uint32_t nhead,
uint32_t table_dim, uint32_t table_dim,
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen, ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead, ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen, ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead, ptrdiff_t x_stride_nhead,
infiniopRoPEAlgo_t algo) { infiniopRoPEAlgo_t algo) {
...@@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel( ...@@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel(
} }
// Task distribution // Task distribution
const int batch_volume = seqlen * nhead; const int batch_volume = batch_size * seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim; const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim; const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0); const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
...@@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel( ...@@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel(
// Main processing loop // Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) { for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
int seq_idx = i / nhead; // Calculate 3D indices from flattened task index
int batch_idx = i / (seqlen * nhead);
int seq_idx = (i % (seqlen * nhead)) / nhead;
int head_idx = i % nhead; int head_idx = i % nhead;
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead; // Calculate offsets with batch dimension
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead; // Note: For GPT-NeoX, the stride calculations might be different
int out_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
int in_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index for this sequence
// Position IDs are shared across batches or per batch depending on input
Tindex pos_idx;
if (use_pos_ids_buffer) {
// Position IDs loaded in NRAM
pos_idx = srcP[seq_idx];
} else {
// Position IDs in global memory
// Handle both cases: position IDs shape could be [seqlen] or [batch_size, seqlen]
if (batch_size > 1) {
// Assume position IDs have shape [batch_size, seqlen]
int pos_flat_idx = batch_idx * seqlen + seq_idx;
pos_idx = pos_ids[pos_flat_idx];
} else {
// Single batch case: position IDs shape is [seqlen]
pos_idx = pos_ids[seq_idx];
}
}
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim; int rot_offset = pos_idx * table_dim;
int processed = 0; int processed = 0;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment