Unverified Commit c112132e authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #839 from InfiniTensor/issue/838

issue/838 - Cambricon Batched RoPE
parents d3e27d8c 5848b408
......@@ -40,8 +40,9 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table,
const Tdata *cos_table,
cnrtQueue_t queue) {
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
auto batch_size = uint32_t(info.batch);
auto seqlen = uint32_t(info.seqlen);
auto nhead = uint32_t(info.nhead);
auto table_dim = uint32_t(info.table_dim);
cnrtDim3_t k_dim;
......@@ -53,12 +54,12 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1;
// Launch kernel
// Launch kernel with batch dimension
ropeKernel<<<k_dim, k_type, queue>>>(
y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead,
batch_size, seqlen, nhead, table_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead,
info.algo);
cnrtQueueSync(queue);
......
......@@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel(
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
uint32_t batch_size,
uint32_t seqlen,
uint32_t nhead,
uint32_t table_dim,
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
infiniopRoPEAlgo_t algo) {
......@@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel(
}
// Task distribution
const int batch_volume = seqlen * nhead;
const int batch_volume = batch_size * seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
......@@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel(
// Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
int seq_idx = i / nhead;
// Calculate 3D indices from flattened task index
int batch_idx = i / (seqlen * nhead);
int seq_idx = (i % (seqlen * nhead)) / nhead;
int head_idx = i % nhead;
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Calculate offsets with batch dimension
// Note: For GPT-NeoX, the stride calculations might be different
int out_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
int in_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index for this sequence
// Position IDs are shared across batches or per batch depending on input
Tindex pos_idx;
if (use_pos_ids_buffer) {
// Position IDs loaded in NRAM
pos_idx = srcP[seq_idx];
} else {
// Position IDs in global memory
// Handle both cases: position IDs shape could be [seqlen] or [batch_size, seqlen]
if (batch_size > 1) {
// Assume position IDs have shape [batch_size, seqlen]
int pos_flat_idx = batch_idx * seqlen + seq_idx;
pos_idx = pos_ids[pos_flat_idx];
} else {
// Single batch case: position IDs shape is [seqlen]
pos_idx = pos_ids[seq_idx];
}
}
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;
int processed = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment