Unverified Commit 20a2dbd6 authored by PanZezhong1725, committed by GitHub
Browse files

Merge pull request #478 from InfiniTensor/issue/477

issue/477 - Cambricon MLU NeoX
parents 6b903fd9 6af2e427
......@@ -21,10 +21,6 @@ infiniStatus_t Descriptor::create(
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(info);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
......@@ -62,7 +58,8 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead);
info.x_stride_seqlen, info.x_stride_nhead,
info.algo);
cnrtQueueSync(queue);
......
#include "../../../devices/bang/common_bang.h"
#include "rope_bang.h"
__nram__ char nram_buffer[NRAM_MAX_SIZE];
......@@ -11,7 +12,9 @@ __mlu_device__ void calculateRope(
Tdata *input_0, Tdata *input_1, Tdata *input_cache,
int theta_index, int out_index, int in_index,
int chunk_size, int half_chunk_size, int data_segsize,
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride,
bool is_gpt_j_style) {
// Load sin/cos data
__memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
__memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
......@@ -19,11 +22,18 @@ __mlu_device__ void calculateRope(
// Load input data
__memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);
// Split input into even and odd positions
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
if (is_gpt_j_style) {
// GPT-J: (x0, x1), (x2, x3), ...
// Split input into even and odd positions
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
} else {
// GPT-NeoX: (x0...xd/2-1), (xd/2...xd-1)
__memcpy(input_0, input_cache, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
__memcpy(input_1, input_cache + half_chunk_size, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
}
// Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos
// Compute rotations
__bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
__bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
__bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
......@@ -31,9 +41,15 @@ __mlu_device__ void calculateRope(
__bang_sub(input_0, x0cos, x1sin, half_chunk_size);
__bang_add(input_1, x0sin, x1cos, half_chunk_size);
// Interleave results back into output buffer
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
if (is_gpt_j_style) {
// GPT-J
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
} else {
// GPT-NeoX
__memcpy(input_cache, input_0, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
__memcpy(input_cache + half_chunk_size, input_1, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
}
// Write back results
__memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
......@@ -52,22 +68,42 @@ __mlu_global__ void ropeKernel(
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ptrdiff_t x_stride_nhead,
infiniopRoPEAlgo_t algo) {
const bool is_gpt_j_style = (algo == INFINIOP_ROPE_ALGO_GPT_J);
// Calculate available NRAM space after alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9);
const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));
// Key variables that determine execution path
const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
// Common stride configurations
const int data_segsize = sizeof(Tdata);
const int src_load_stride = 2 * sizeof(Tdata);
const int dst_load_stride = 1 * sizeof(Tdata);
const int src_write_stride = 1 * sizeof(Tdata);
const int dst_write_stride = 2 * sizeof(Tdata);
int half_chunk_size;
if (is_gpt_j_style) {
half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
} else {
half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
}
int data_segsize, src_load_stride, dst_load_stride, src_write_stride, dst_write_stride;
if (is_gpt_j_style) {
// GPT-J
data_segsize = sizeof(Tdata);
src_load_stride = 2 * sizeof(Tdata);
dst_load_stride = 1 * sizeof(Tdata);
src_write_stride = 1 * sizeof(Tdata);
dst_write_stride = 2 * sizeof(Tdata);
} else {
// GPT-NeoX
data_segsize = half_chunk_size * sizeof(Tdata);
src_load_stride = 1 * sizeof(Tdata);
dst_load_stride = 1 * sizeof(Tdata);
src_write_stride = 1 * sizeof(Tdata);
dst_write_stride = 1 * sizeof(Tdata);
}
// Task distribution
const int batch_volume = seqlen * nhead;
......@@ -100,29 +136,29 @@ __mlu_global__ void ropeKernel(
// Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
// Calculate output and input indices
int seq_idx = i / nhead;
int head_idx = i % nhead;
// Output indices (y)
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
// Input indices (x)
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;
// Process in chunks that fit in NRAM
int processed = 0;
while (processed < table_dim) {
// Calculate current chunk size
int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
int current_chunk_size = 2 * current_half_chunk;
int theta_offset = rot_offset + processed;
int dst_offset = out_offset + processed * 2;
int src_offset = in_offset + processed * 2;
int dst_offset, src_offset;
if (is_gpt_j_style) {
dst_offset = out_offset + processed * 2;
src_offset = in_offset + processed * 2;
} else {
dst_offset = out_offset + processed;
src_offset = in_offset + processed;
}
// Set up NRAM buffers for this chunk
char *chunk_base = aligned_nram;
......@@ -143,7 +179,8 @@ __mlu_global__ void ropeKernel(
theta_offset, dst_offset, src_offset,
current_chunk_size, current_half_chunk,
data_segsize,
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride,
is_gpt_j_style);
processed += current_half_chunk;
}
......
......@@ -97,7 +97,6 @@ def rotary_embedding(ans, t, sin, cos, device, algo):
return t_out_1, t_out_2
dh = t.shape[-1]
dt = t.dtype
assert dh % 2 == 0, "Embedding dimension must be even."
......@@ -111,7 +110,7 @@ def rotary_embedding(ans, t, sin, cos, device, algo):
ans[..., 0::2] = t_out_even.to(dt)
ans[..., 1::2] = t_out_odd.to(dt)
else:
half_dim = dh // 2
half_dim = dh // 2
t_first = t[..., :half_dim]
t_second = t[..., half_dim:]
......@@ -232,6 +231,7 @@ def test(
sin_table.torch_tensor(),
cos_table.torch_tensor(),
device,
algo,
),
device,
NUM_PRERUN,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment