Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
4d3ca831
Unverified
Commit
4d3ca831
authored
Aug 20, 2025
by
PanZezhong1725
Committed by
GitHub
Aug 20, 2025
Browse files
Merge pull request #266 from InfiniTensor/issue/265
issue/265 - Implemented the RoPE operator for Cambricon Bang
parents
e883cc8a
e3669dfc
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
295 additions
and
35 deletions
+295
-35
src/infiniop/ops/rope/bang/rope_bang.h
src/infiniop/ops/rope/bang/rope_bang.h
+8
-0
src/infiniop/ops/rope/bang/rope_bang.mlu
src/infiniop/ops/rope/bang/rope_bang.mlu
+125
-0
src/infiniop/ops/rope/bang/rope_bang_kernel.mlu
src/infiniop/ops/rope/bang/rope_bang_kernel.mlu
+151
-0
src/infiniop/ops/rope/operator.cc
src/infiniop/ops/rope/operator.cc
+11
-35
No files found.
src/infiniop/ops/rope/bang/rope_bang.h
0 → 100644
View file @
4d3ca831
#ifndef __INFINIOP_ROPE_BANG_H__
#define __INFINIOP_ROPE_BANG_H__
#include "../rope.h"
DESCRIPTOR
(
bang
)
#endif // __INFINIOP_ROPE_BANG_H__
src/infiniop/ops/rope/bang/rope_bang.mlu
0 → 100644
View file @
4d3ca831
#include "../../../devices/bang/common_bang.h"
#include "rope_bang.h"
#include "rope_bang_kernel.mlu"
namespace op::rope::bang {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
CHECK_RESULT(info);
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
0,
nullptr,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata, typename Tindex>
infiniStatus_t calculateRoPE(const RoPEInfo &info,
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
cnrtQueue_t queue) {
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
auto table_dim = uint32_t(info.table_dim);
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
// Configure kernel launch parameters
k_dim.x = 4;
k_dim.y = 1;
k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1;
// Launch kernel
ropeKernel<<<k_dim, k_type, queue>>>(
y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead);
cnrtQueueSync(queue);
return INFINI_STATUS_SUCCESS;
}
#define CALCULATE_ROPE(TDATA, TINDEX) \
calculateRoPE(_info, \
(TDATA *)y, \
(const TDATA *)x, \
(const TINDEX *)pos_ids, \
(const TDATA *)sin_table, \
(const TDATA *)cos_table, \
(cnrtQueue_t)stream)
#define ROPE_TYPE(TDATA) \
switch (_info.pos_type) { \
case INFINI_DTYPE_U8: \
return CALCULATE_ROPE(TDATA, uint8_t); \
case INFINI_DTYPE_U16: \
return CALCULATE_ROPE(TDATA, uint16_t); \
case INFINI_DTYPE_U32: \
return CALCULATE_ROPE(TDATA, uint32_t); \
case INFINI_DTYPE_U64: \
return CALCULATE_ROPE(TDATA, uint64_t); \
case INFINI_DTYPE_I8: \
return CALCULATE_ROPE(TDATA, int8_t); \
case INFINI_DTYPE_I16: \
return CALCULATE_ROPE(TDATA, int16_t); \
case INFINI_DTYPE_I32: \
return CALCULATE_ROPE(TDATA, int32_t); \
case INFINI_DTYPE_I64: \
return CALCULATE_ROPE(TDATA, int64_t); \
default: \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *y,
const void *x,
const void *pos_ids,
const void *sin_table,
const void *cos_table,
void *stream) const {
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(bfloat16_t);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::bang
src/infiniop/ops/rope/bang/rope_bang_kernel.mlu
0 → 100644
View file @
4d3ca831
#include "../../../devices/bang/common_bang.h"
__nram__ char nram_buffer[NRAM_MAX_SIZE];
template <typename Tdata>
__mlu_device__ void calculateRope(
Tdata *out, const Tdata *in,
const Tdata *sin_table, const Tdata *cos_table,
Tdata *sin_cache, Tdata *cos_cache,
Tdata *x1sin, Tdata *x0cos, Tdata *x0sin, Tdata *x1cos,
Tdata *input_0, Tdata *input_1, Tdata *input_cache,
int theta_index, int out_index, int in_index,
int chunk_size, int half_chunk_size, int data_segsize,
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
// Load sin/cos data
__memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
__memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Load input data
__memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Split input into even and odd positions
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
// Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos
__bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
__bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
__bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
__bang_mul(x1cos, input_1, cos_cache, half_chunk_size);
__bang_sub(input_0, x0cos, x1sin, half_chunk_size);
__bang_add(input_1, x0sin, x1cos, half_chunk_size);
// Interleave results back into output buffer
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
// Write back results
__memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
}
template <typename Tdata, typename Tindex>
__mlu_global__ void ropeKernel(
Tdata *y,
const Tdata *x,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table,
uint32_t seqlen,
uint32_t nhead,
uint32_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
// Calculate available NRAM space after alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));
// Key variables that determine execution path
const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
// Common stride configurations
const int data_segsize = sizeof(Tdata);
const int src_load_stride = 2 * sizeof(Tdata);
const int dst_load_stride = 1 * sizeof(Tdata);
const int src_write_stride = 1 * sizeof(Tdata);
const int dst_write_stride = 2 * sizeof(Tdata);
// Task distribution
const int batch_volume = seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
const int task_start_idx = (taskId < remaining_tasks ? taskId * base_tasks_per_core + taskId : taskId * base_tasks_per_core + remaining_tasks);
// NRAM buffer allocation with proper alignment
char *aligned_nram = (char *)(((size_t)nram_buffer + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
// Setup position IDs if they fit in NRAM
Tindex *srcP = nullptr;
if (use_pos_ids_buffer) {
srcP = (Tindex *)aligned_nram;
__memcpy(srcP, pos_ids, seqlen * sizeof(Tindex), GDRAM2NRAM);
aligned_nram = (char *)(((size_t)srcP + seqlen * sizeof(Tindex) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));
}
// Main processing buffers (pointers will be set per chunk)
Tdata *sin_cache = nullptr;
Tdata *cos_cache = nullptr;
Tdata *x1sin = nullptr;
Tdata *x0cos = nullptr;
Tdata *x0sin = nullptr;
Tdata *x1cos = nullptr;
Tdata *input_0 = nullptr;
Tdata *input_1 = nullptr;
Tdata *input_cache = nullptr;
// Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
// Calculate output and input indices
int seq_idx = i / nhead;
int head_idx = i % nhead;
// Output indices (y)
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
// Input indices (x)
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;
// Process in chunks that fit in NRAM
int processed = 0;
while (processed < table_dim) {
// Calculate current chunk size
int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
int current_chunk_size = 2 * current_half_chunk;
int theta_offset = rot_offset + processed;
int dst_offset = out_offset + processed * 2;
int src_offset = in_offset + processed * 2;
// Set up NRAM buffers for this chunk
char *chunk_base = aligned_nram;
sin_cache = (Tdata *)chunk_base;
cos_cache = sin_cache + current_half_chunk;
x1sin = cos_cache + current_half_chunk;
x0cos = x1sin + current_half_chunk;
x0sin = x0cos + current_half_chunk;
x1cos = x0sin + current_half_chunk;
input_0 = x1cos + current_half_chunk;
input_1 = input_0 + current_half_chunk;
input_cache = input_1 + current_half_chunk;
calculateRope<Tdata>(
y, x, sin_table, cos_table,
sin_cache, cos_cache, x1sin, x0cos, x0sin, x1cos,
input_0, input_1, input_cache,
theta_offset, dst_offset, src_offset,
current_chunk_size, current_half_chunk,
data_segsize,
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
processed += current_half_chunk;
}
}
}
src/infiniop/ops/rope/operator.cc
View file @
4d3ca831
...
...
@@ -11,6 +11,9 @@
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_ascend.h"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/rope_bang.h"
#endif
#ifdef ENABLE_METAX_API
#include "metax/rope_metax.h"
#endif
...
...
@@ -51,12 +54,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
#ifdef ENABLE_ASCEND_API
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangCreateRoPEDescriptor
((
BangHandle_t
)
handle
,
(
RoPEBangDescriptor_t
*
)
desc_ptr
,
t
,
pos_ids
,
sin_table
,
cos_table
);
}
#ifdef ENABLE_CAMBRICON_API
CREATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
#ifdef ENABLE_MTHREADS_GPU
case
DevMthreadsGpu
:
{
...
...
@@ -92,19 +91,12 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangGetRoPEWorkspaceSize
((
RoPEBangDescriptor_t
)
desc
,
size
);
}
#ifdef ENABLE_CAMBRICON_API
GET
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
#ifdef ENABLE_ASCEND_API
GET
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
return
macaGetRoPEWorkspaceSize
((
RoPEMacaDescriptor_t
)
desc
,
size
);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case
DevMthreadsGpu
:
{
return
musaGetRoPEWorkspaceSize
((
RoPEMusaDescriptor_t
)
desc
,
size
);
...
...
@@ -146,21 +138,12 @@ __C infiniStatus_t infiniopRoPE(
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangRoPE
((
RoPEBangDescriptor_t
)
desc
,
workspace
,
workspace_size
,
t
,
pos_ids
,
sin_table
,
cos_table
,
stream
);
}
#ifdef ENABLE_CAMBRICON_API
CALCULATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
#ifdef ENABLE_ASCEND_API
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
return
macaRoPE
((
RoPEMacaDescriptor_t
)
desc
,
workspace
,
workspace_size
,
t
,
pos_ids
,
sin_table
,
cos_table
,
stream
);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case
DevMthreadsGpu
:
{
return
musaRoPE
((
RoPEMusaDescriptor_t
)
desc
,
workspace
,
workspace_size
,
...
...
@@ -195,19 +178,12 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangDestroyRoPEDescriptor
((
RoPEBangDescriptor_t
)
desc
);
}
#ifdef ENABLE_CAMBRICON_API
DELETE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
#ifdef ENABLE_ASCEND_API
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
return
macaDestroyRoPEDescriptor
((
RoPEMacaDescriptor_t
)
desc
);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case
DevMthreadsGpu
:
{
return
musaDestroyRoPEDescriptor
((
RoPEMusaDescriptor_t
)
desc
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment