Unverified Commit 51beebc6 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Issue/739 在cpu, nvidia, metax, moore threads支持batched rope (#743)

* issue/739 - support batched RoPE on Nvidia and CPU

* issue/739 - metax, moore batched rope

* issue/739 - adjust metax flags

* issue/739 - added a RoPE module interface that writes the forward result into a caller-provided output tensor
parent f73d6237
......@@ -52,6 +52,24 @@ public:
*/
Tensor forward(const Tensor &x, const Tensor &pos, bool in_place = false) const;
/**
* @brief Forward pass: apply RoPE, writing the result into a caller-provided output tensor
*
* @param y Output tensor of shape (..., head_dim); receives the rotated values
* @param x Input tensor of shape (..., head_dim) where ... is any number of dimensions
* @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
* @return The output tensor y, with the same shape as the input
*
* Applies rotary position embeddings to the input tensor.
* For attention mechanisms, call this method separately for query and key tensors.
*
* Common input shapes:
* - [batch, num_heads, seq_len, head_dim]
* - [batch, seq_len, num_heads, head_dim]
* - [seq_len, head_dim]
*/
Tensor forward(const Tensor &y, const Tensor &x, const Tensor &pos) const;
// Module information
size_t head_dim() const { return head_dim_; }
size_t max_seq_len() const { return max_seq_len_; }
......
......@@ -122,6 +122,11 @@ Tensor RoPE::forward(const Tensor &x, const Tensor &pos, bool in_place) const {
return op::rope(x, pos, sin_cache_, cos_cache_, algo_);
}
// Out-of-place forward: applies rotary position embeddings to `x` using the
// module's precomputed sin/cos caches and the configured algorithm, writing
// the rotated values into the caller-provided output tensor `y`.
Tensor RoPE::forward(const Tensor &y, const Tensor &x, const Tensor &pos) const {
// Delegate to the functional op; `y` is both the destination and the return value.
op::rope_(y, x, pos, sin_cache_, cos_cache_, algo_);
return y;
}
std::string RoPE::extra_repr() const {
std::string algo_str = (algo_ == Algo::GPT_J) ? "GPT_J" : "GPT_NEOX";
return "RoPE(head_dim=" + std::to_string(head_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
......
......@@ -38,19 +38,48 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tindex *pos_ids,
const Tdata *sin_table,
const Tdata *cos_table) {
// Calculate position ID stride for batch dimension
size_t pos_stride_batch = info.pos_has_batch_dim ? info.seqlen : 0;
// Parallelize over the outer batch loop only (no collapse clause); the head and token loops run serially within each thread
#pragma omp parallel for
for (ptrdiff_t b = 0; b < ptrdiff_t(info.batch); b++) {
for (ptrdiff_t h = 0; h < ptrdiff_t(info.nhead); h++) {
for (size_t tok = 0; tok < info.seqlen; tok++) {
size_t x_offset = tok * info.x_stride_seqlen + h * info.x_stride_nhead;
size_t y_offset = tok * info.y_stride_seqlen + h * info.y_stride_nhead;
size_t pos_id = size_t(pos_ids[tok]);
// Calculate memory offsets with batch dimension
size_t x_offset = (info.has_batch_dim ? b * info.x_stride_batch : 0) + tok * info.x_stride_seqlen + h * info.x_stride_nhead;
size_t y_offset = (info.has_batch_dim ? b * info.y_stride_batch : 0) + tok * info.y_stride_seqlen + h * info.y_stride_nhead;
// Calculate position ID offset
size_t pos_offset;
if (info.pos_has_batch_dim) {
// Per-batch position IDs
pos_offset = b * pos_stride_batch + tok;
} else {
// Shared position IDs across batch
pos_offset = tok;
}
size_t pos_id = size_t(pos_ids[pos_offset]);
size_t table_offset = pos_id * info.table_dim;
for (size_t i = 0; i < info.table_dim; i++) {
size_t pos0 = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J ? 2 * i : i;
size_t pos1 = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J ? 2 * i + 1 : i + info.table_dim;
// Calculate positions based on algorithm
size_t pos0, pos1;
if (info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J) {
// GPT-J style: interleaved pairs
pos0 = 2 * i;
pos1 = 2 * i + 1;
} else {
// Original style: first half and second half
pos0 = i;
pos1 = i + info.table_dim;
}
if constexpr (std::is_same<Tdata, fp16_t>::value || std::is_same<Tdata, bf16_t>::value) {
// Convert to float for computation
float x0 = utils::cast<float>(x[x_offset + pos0]),
x1 = utils::cast<float>(x[x_offset + pos1]),
sin__ = utils::cast<float>(sin_table[table_offset + i]),
......@@ -59,6 +88,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
y[y_offset + pos0] = utils::cast<Tdata>(x0 * cos__ - x1 * sin__);
y[y_offset + pos1] = utils::cast<Tdata>(x0 * sin__ + x1 * cos__);
} else {
// Use native types
Tdata x0 = x[x_offset + pos0],
x1 = x[x_offset + pos1],
sin__ = sin_table[table_offset + i],
......@@ -70,6 +100,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
}
}
}
}
return INFINI_STATUS_SUCCESS;
}
......@@ -121,6 +152,8 @@ infiniStatus_t Descriptor::calculate(
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
#undef ROPE_TYPE
......
......@@ -9,14 +9,37 @@ __device__ void ropeThreadPerItemBlock(
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
size_t pos_stride_batch, // Stride for batch dimension in pos_ids (0 if 1D)
bool pos_has_batch_dim, // Whether pos_ids has batch dimension
bool has_batch_dim, // Whether tensors have batch dimension
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
auto y_offset = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
auto x_offset = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;
size_t pos_id = size_t(pos_ids[blockIdx.x]);
// Calculate batch index: use blockIdx.z for 4D, 0 for 3D
const size_t batch_idx = has_batch_dim ? blockIdx.z : 0;
const size_t seq_idx = blockIdx.x;
const size_t head_idx = blockIdx.y;
// Calculate memory offsets
auto y_offset = (has_batch_dim ? batch_idx * y_stride_batch : 0) + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
auto x_offset = (has_batch_dim ? batch_idx * x_stride_batch : 0) + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Calculate position ID offset
size_t pos_offset;
if (pos_has_batch_dim) {
// Per-batch position IDs
pos_offset = batch_idx * pos_stride_batch + seq_idx;
} else {
// Shared position IDs across batch
pos_offset = seq_idx;
}
size_t pos_id = size_t(pos_ids[pos_offset]);
auto table_offset = pos_id * table_dim;
for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
......
......@@ -13,16 +13,24 @@ INFINIOP_METAX_KERNEL ropeThreadPerItemKernel(
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
size_t pos_stride_batch, // Stride for batch dimension in pos_ids
bool pos_has_batch_dim, // Whether pos_ids has batch dimension
bool has_batch_dim, // Whether tensors have batch dimension
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock<IsGPTJ>(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
y_stride_seqlen, y_stride_nhead,
x_stride_seqlen, x_stride_nhead);
pos_stride_batch,
pos_has_batch_dim,
has_batch_dim,
y_stride_batch, y_stride_seqlen, y_stride_nhead,
x_stride_batch, x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::metax {
......@@ -70,19 +78,43 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table,
const Tdata *cos_table,
hcStream_t stream) {
auto dimx = uint32_t(info.seqlen),
dimy = uint32_t(info.nhead);
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
// For 3D tensors, batch dimension is not in blockIdx
auto dimz = info.has_batch_dim ? uint32_t(info.batch) : 1;
int nthreads = std::max(int(info.table_dim), block_size);
bool is_gpt_j = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J;
// Calculate position ID stride for batch dimension
size_t pos_stride_batch = info.pos_has_batch_dim ? info.seqlen : 0;
dim3 grid_dim;
if (info.has_batch_dim) {
// 4D tensors: use 3D grid [seqlen, nhead, batch]
grid_dim = dim3(dimx, dimy, dimz);
} else {
// 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1
grid_dim = dim3(dimx, dimy);
}
if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
ropeThreadPerItemKernel<true><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
}
return INFINI_STATUS_SUCCESS;
......
......@@ -16,14 +16,37 @@ __device__ void ropeThreadPerItemBlock(
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
size_t pos_stride_batch, // Stride for batch dimension in pos_ids (0 if 1D)
bool pos_has_batch_dim, // Whether pos_ids has batch dimension
bool has_batch_dim, // Whether tensors have batch dimension
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
auto y_offset = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
auto x_offset = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;
size_t pos_id = size_t(pos_ids[blockIdx.x]);
// Calculate batch index: use blockIdx.z for 4D, 0 for 3D
const size_t batch_idx = has_batch_dim ? blockIdx.z : 0;
const size_t seq_idx = blockIdx.x;
const size_t head_idx = blockIdx.y;
// Calculate memory offsets
auto y_offset = (has_batch_dim ? batch_idx * y_stride_batch : 0) + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
auto x_offset = (has_batch_dim ? batch_idx * x_stride_batch : 0) + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Calculate position ID offset
size_t pos_offset;
if (pos_has_batch_dim) {
// Per-batch position IDs
pos_offset = batch_idx * pos_stride_batch + seq_idx;
} else {
// Shared position IDs across batch
pos_offset = seq_idx;
}
size_t pos_id = size_t(pos_ids[pos_offset]);
auto table_offset = pos_id * table_dim;
for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
......
......@@ -13,16 +13,24 @@ INFINIOP_MOORE_KERNEL ropeThreadPerItemKernel(
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
size_t pos_stride_batch, // Stride for batch dimension in pos_ids
bool pos_has_batch_dim, // Whether pos_ids has batch dimension
bool has_batch_dim, // Whether tensors have batch dimension
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock<IsGPTJ>(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
y_stride_seqlen, y_stride_nhead,
x_stride_seqlen, x_stride_nhead);
pos_stride_batch,
pos_has_batch_dim,
has_batch_dim,
y_stride_batch, y_stride_seqlen, y_stride_nhead,
x_stride_batch, x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::moore {
......@@ -70,19 +78,43 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table,
const Tdata *cos_table,
musaStream_t stream) {
auto dimx = uint32_t(info.seqlen),
dimy = uint32_t(info.nhead);
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
// For 3D tensors, batch dimension is not in blockIdx
auto dimz = info.has_batch_dim ? uint32_t(info.batch) : 1;
int nthreads = std::max(int(info.table_dim), block_size);
bool is_gpt_j = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J;
// Calculate position ID stride for batch dimension
size_t pos_stride_batch = info.pos_has_batch_dim ? info.seqlen : 0;
dim3 grid_dim;
if (info.has_batch_dim) {
// 4D tensors: use 3D grid [seqlen, nhead, batch]
grid_dim = dim3(dimx, dimy, dimz);
} else {
// 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1
grid_dim = dim3(dimx, dimy);
}
if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
ropeThreadPerItemKernel<true><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
}
return INFINI_STATUS_SUCCESS;
......
......@@ -13,16 +13,24 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel(
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
size_t pos_stride_batch, // Stride for batch dimension in pos_ids
bool pos_has_batch_dim, // Whether pos_ids has batch dimension
bool has_batch_dim, // Whether tensors have batch dimension
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ropeThreadPerItemBlock<IsGPTJ>(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
y_stride_seqlen, y_stride_nhead,
x_stride_seqlen, x_stride_nhead);
pos_stride_batch,
pos_has_batch_dim,
has_batch_dim,
y_stride_batch, y_stride_seqlen, y_stride_nhead,
x_stride_batch, x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::nvidia {
......@@ -70,19 +78,43 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table,
const Tdata *cos_table,
cudaStream_t stream) {
auto dimx = uint32_t(info.seqlen),
dimy = uint32_t(info.nhead);
auto dimx = uint32_t(info.seqlen);
auto dimy = uint32_t(info.nhead);
// For 3D tensors, batch dimension is not in blockIdx
auto dimz = info.has_batch_dim ? uint32_t(info.batch) : 1;
int nthreads = std::max(int(info.table_dim), block_size);
bool is_gpt_j = info.algo == infiniopRoPEAlgo_t::INFINIOP_ROPE_ALGO_GPT_J;
// Calculate position ID stride for batch dimension
size_t pos_stride_batch = info.pos_has_batch_dim ? info.seqlen : 0;
dim3 grid_dim;
if (info.has_batch_dim) {
// 4D tensors: use 3D grid [seqlen, nhead, batch]
grid_dim = dim3(dimx, dimy, dimz);
} else {
// 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1
grid_dim = dim3(dimx, dimy);
}
if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
ropeThreadPerItemKernel<true><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<dim3(dimx, dimy), nthreads, 0, stream>>>(
ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
}
return INFINI_STATUS_SUCCESS;
......
......@@ -59,12 +59,16 @@ private:
public:
infiniDtype_t data_type, pos_type;
size_t seqlen, nhead, dhead, table_len, table_dim;
size_t batch, seqlen, nhead, dhead, table_len, table_dim;
ptrdiff_t
y_stride_batch, // Batch stride (0 for 3D tensors)
y_stride_seqlen,
y_stride_nhead,
x_stride_batch, // Batch stride (0 for 3D tensors)
x_stride_seqlen,
x_stride_nhead;
bool has_batch_dim; // Whether tensors have batch dimension
bool pos_has_batch_dim; // Whether position IDs have batch dimension
infiniopRoPEAlgo_t algo;
static utils::Result<RoPEInfo>
......@@ -86,43 +90,114 @@ public:
CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE_ANY_INT(pos_type);
CHECK_OR_RETURN(y_desc->ndim() == 3
&& x_desc->ndim() == 3
&& pos_desc->ndim() == 1
&& sin_desc->ndim() == 2
&& cos_desc->ndim() == 2,
// Support both 3D (no batch) and 4D (with batch) tensors
bool y_has_batch = y_desc->ndim() == 4;
bool x_has_batch = x_desc->ndim() == 4;
CHECK_OR_RETURN(y_has_batch == x_has_batch, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN((y_has_batch && x_has_batch) || (y_desc->ndim() == 3 && x_desc->ndim() == 3), INFINI_STATUS_BAD_TENSOR_SHAPE);
// Check position IDs: can be 1D (shared) or 2D (per-batch)
bool pos_has_batch = pos_desc->ndim() == 2;
CHECK_OR_RETURN(pos_desc->ndim() == 1 || pos_desc->ndim() == 2,
INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(sin_desc->ndim() == 2 && cos_desc->ndim() == 2,
INFINI_STATUS_BAD_TENSOR_SHAPE);
const auto seqlen = y_desc->dim(0),
nhead = y_desc->dim(1),
dhead = y_desc->dim(2),
table_len = sin_desc->dim(0),
table_dim = sin_desc->dim(1);
size_t batch, seqlen, nhead, dhead;
if (y_has_batch) {
// 4D tensors: [batch, seqlen, nhead, dhead]
batch = y_desc->dim(0);
seqlen = y_desc->dim(1);
nhead = y_desc->dim(2);
dhead = y_desc->dim(3);
CHECK_OR_RETURN(seqlen == x_desc->dim(0)
&& seqlen == pos_desc->dim(0)
&& nhead == x_desc->dim(1) && dhead == x_desc->dim(2)
&& table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1),
CHECK_OR_RETURN(batch == x_desc->dim(0) && seqlen == x_desc->dim(1) && nhead == x_desc->dim(2) && dhead == x_desc->dim(3),
INFINI_STATUS_BAD_TENSOR_SHAPE);
} else {
// 3D tensors: [seqlen, nhead, dhead] (batch = 1)
batch = 1;
seqlen = y_desc->dim(0);
nhead = y_desc->dim(1);
dhead = y_desc->dim(2);
CHECK_OR_RETURN(seqlen == x_desc->dim(0) && nhead == x_desc->dim(1) && dhead == x_desc->dim(2),
INFINI_STATUS_BAD_TENSOR_SHAPE);
}
const auto table_len = sin_desc->dim(0);
const auto table_dim = sin_desc->dim(1);
// Check position IDs shape
if (pos_has_batch) {
// 2D position IDs: [batch, seqlen]
CHECK_OR_RETURN(batch == pos_desc->dim(0) && seqlen == pos_desc->dim(1),
INFINI_STATUS_BAD_TENSOR_SHAPE);
} else {
// 1D position IDs: [seqlen]
CHECK_OR_RETURN(seqlen == pos_desc->dim(0),
INFINI_STATUS_BAD_TENSOR_SHAPE);
}
CHECK_OR_RETURN(table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1),
INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(dhead == table_dim * 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
// Last dimension of x and y must be contiguous
CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
if (y_has_batch) {
CHECK_OR_RETURN(y_desc->stride(3) == 1 && x_desc->stride(3) == 1,
INFINI_STATUS_BAD_TENSOR_STRIDES);
} else {
CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1,
INFINI_STATUS_BAD_TENSOR_STRIDES);
}
// sin table and cos table must be totally contiguous
CHECK_OR_RETURN(sin_desc->isContiguous() && cos_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(sin_desc->isContiguous() && cos_desc->isContiguous(),
INFINI_STATUS_BAD_TENSOR_STRIDES);
// Set strides based on tensor dimensions
ptrdiff_t y_stride_batch, y_stride_seqlen, y_stride_nhead;
ptrdiff_t x_stride_batch, x_stride_seqlen, x_stride_nhead;
if (y_has_batch) {
y_stride_batch = y_desc->stride(0);
y_stride_seqlen = y_desc->stride(1);
y_stride_nhead = y_desc->stride(2);
x_stride_batch = x_desc->stride(0);
x_stride_seqlen = x_desc->stride(1);
x_stride_nhead = x_desc->stride(2);
} else {
// For 3D tensors, set batch stride to 0 (no batch dimension)
y_stride_batch = 0;
y_stride_seqlen = y_desc->stride(0);
y_stride_nhead = y_desc->stride(1);
x_stride_batch = 0;
x_stride_seqlen = x_desc->stride(0);
x_stride_nhead = x_desc->stride(1);
}
return utils::Result<RoPEInfo>(RoPEInfo{
data_type,
pos_type,
batch,
seqlen,
nhead,
dhead,
table_len,
table_dim,
y_desc->stride(0),
y_desc->stride(1),
x_desc->stride(0),
x_desc->stride(1),
y_stride_batch,
y_stride_seqlen,
y_stride_nhead,
x_stride_batch,
x_stride_seqlen,
x_stride_nhead,
y_has_batch, // has_batch_dim
pos_has_batch, // pos_has_batch_dim
algo,
});
}
......
......@@ -33,6 +33,10 @@ _TEST_CASES_ = [
((4, 1, 32), (64, 64, 1), None),
((11, 33, 128), None, (8000, 200, 1)),
((3, 32, 128), (8000, 200, 1), (7000, 128, 1)),
((8, 1, 32, 128), None, None),
((8, 10, 32, 64), None, None),
((8, 20, 32, 64), (40960, 64, 1280, 1), (40960, 64, 1280, 1)),
((8, 20, 4, 64), (1048576, 64, 262144, 1), (1048576, 64, 262144, 1)),
]
# Data types used for testing
......@@ -153,9 +157,9 @@ def test(
f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace} algo:{algo}"
)
theta = 1e5
pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device)
pos = TestTensor.from_torch(torch.arange(0, x.shape[-3]), InfiniDtype.I32, device)
sin_table, cos_table = sin_cos_table(
pos.torch_tensor(), x.shape[2], x.device, theta, dtype
pos.torch_tensor(), x.shape[-1], x.device, theta, dtype
)
rotary_embedding(
......
......@@ -47,7 +47,7 @@ target("infiniop-metax")
on_install(function (target) end)
set_languages("cxx17")
set_warnings("all", "error")
add_cxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing")
add_cxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = true})
add_files("../src/infiniop/devices/metax/*.cc", "../src/infiniop/ops/*/metax/*.cc")
add_files("../src/infiniop/ops/*/metax/*.maca", {rule = "maca"})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment