Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c112132e
Unverified
Commit
c112132e
authored
Feb 11, 2026
by
thatPepe
Committed by
GitHub
Feb 11, 2026
Browse files
Merge pull request #839 from InfiniTensor/issue/838
issue/838 - Cambricon Batched RoPE
parents
d3e27d8c
5848b408
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
11 deletions
+37
-11
src/infiniop/ops/rope/bang/rope_bang.mlu
src/infiniop/ops/rope/bang/rope_bang.mlu
+7
-6
src/infiniop/ops/rope/bang/rope_bang_kernel.mlu
src/infiniop/ops/rope/bang/rope_bang_kernel.mlu
+30
-5
No files found.
src/infiniop/ops/rope/bang/rope_bang.mlu
View file @
c112132e
...
...
@@ -40,8 +40,9 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table,
const Tdata *cos_table,
cnrtQueue_t queue) {
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
auto batch_size = uint32_t(info.batch);
auto seqlen = uint32_t(info.seqlen);
auto nhead = uint32_t(info.nhead);
auto table_dim = uint32_t(info.table_dim);
cnrtDim3_t k_dim;
...
...
@@ -53,12 +54,12 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1;
// Launch kernel
// Launch kernel
with batch dimension
ropeKernel<<<k_dim, k_type, queue>>>(
y, x, pos_ids, sin_table, cos_table,
dimx, dimy
, table_dim,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead,
batch_size, seqlen, nhead
, table_dim,
info.y_stride_batch,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch,
info.x_stride_seqlen, info.x_stride_nhead,
info.algo);
cnrtQueueSync(queue);
...
...
src/infiniop/ops/rope/bang/rope_bang_kernel.mlu
View file @
c112132e
...
...
@@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel(
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
uint32_t batch_size,
uint32_t seqlen,
uint32_t nhead,
uint32_t table_dim,
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead,
infiniopRoPEAlgo_t algo) {
...
...
@@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel(
}
// Task distribution
const int batch_volume = seqlen * nhead;
const int batch_volume =
batch_size *
seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
...
...
@@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel(
// Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
int seq_idx = i / nhead;
// Calculate 3D indices from flattened task index
int batch_idx = i / (seqlen * nhead);
int seq_idx = (i % (seqlen * nhead)) / nhead;
int head_idx = i % nhead;
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Calculate offsets with batch dimension
// Note: For GPT-NeoX, the stride calculations might be different
int out_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
int in_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index for this sequence
// Position IDs are shared across batches or per batch depending on input
Tindex pos_idx;
if (use_pos_ids_buffer) {
// Position IDs loaded in NRAM
pos_idx = srcP[seq_idx];
} else {
// Position IDs in global memory
// Handle both cases: position IDs shape could be [seqlen] or [batch_size, seqlen]
if (batch_size > 1) {
// Assume position IDs have shape [batch_size, seqlen]
int pos_flat_idx = batch_idx * seqlen + seq_idx;
pos_idx = pos_ids[pos_flat_idx];
} else {
// Single batch case: position IDs shape is [seqlen]
pos_idx = pos_ids[seq_idx];
}
}
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;
int processed = 0;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment