Unverified Commit 8d09630a authored by gongchensu's avatar gongchensu Committed by GitHub
Browse files

Merge branch 'demo131' into Issue/862

parents ab52dead 012df56c
#include "../../../../devices/nvidia/nvidia_common.cuh"
#include "per_channel_quant_int8_nvidia.cuh"
#include "../../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../../reduce/cuda/reduce.cuh"
#include <cub/block/block_reduce.cuh>
#include "../cuda/kernel.cuh"
// Asymmetric per-channel int8 quantization, block-per-row variant.
// Launched below (per_channel_quant_int8Kernel) with grid = M when K >= 1024,
// i.e. one thread block of BLOCK_SIZE threads per row of the M x K input.
// Thin __global__ wrapper that forwards to the device-side implementation
// in ../cuda/kernel.cuh. x_packed receives the int8 data; x_scale and x_zero
// are per-row quantization parameters (presumably scale and zero point --
// confirm against kernel.cuh).
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
blockPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x_zero, x, M, K);
}
// Symmetric per-channel int8 quantization (no zero point), block-per-row
// variant. Selected by per_channel_quant_int8Kernel when the caller passes
// x_zero == nullptr and K >= 1024; launched with grid = M, one block per row.
// Thin __global__ wrapper around blockPerChannelQuantI8SymKernel
// (../cuda/kernel.cuh).
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
blockPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x, M, K);
}
// Asymmetric per-channel int8 quantization for short rows (K < 1024).
// Launched below with BLOCK_SIZE_x x BLOCK_SIZE_y (32x32) thread blocks and
// grid.x = ceil(M / BLOCK_SIZE_y), so each block covers BLOCK_SIZE_y rows
// (presumably one warp in x per row -- confirm against kernel.cuh).
// Thin __global__ wrapper around warpPerChannelQuantI8Kernel.
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
warpPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x_zero, x, M, K);
}
// Symmetric (no zero point) counterpart of warpPerChannelQuantI8; selected
// when x_zero == nullptr and K < 1024. Same launch shape: 32x32 blocks,
// grid.x = ceil(M / BLOCK_SIZE_y). Thin __global__ wrapper around
// warpPerChannelQuantI8SymKernel (../cuda/kernel.cuh).
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
warpPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x, M, K);
}
namespace op::per_channel_quant_int8::nvidia {
// Backend-private state for the NVIDIA descriptor: a shared_ptr keeps the
// device handle internals (used for maxThreadsPerBlock in calculate) alive
// for the lifetime of the descriptor.
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
// Releases the opaque backend state allocated in Descriptor::create.
Descriptor::~Descriptor() {
delete _opaque;
}
/**
 * Creates an NVIDIA descriptor for per-channel int8 quantization.
 *
 * Validates the tensor descriptors via
 * PerChannelQuantI8Info::createPerChannelQuantI8Info (CHECK_RESULT returns
 * early on failure) and stores the resulting info, the device handle
 * internals, and a zero workspace size in a newly allocated Descriptor.
 *
 * @param handle        runtime handle; must wrap a device::nvidia::Handle
 * @param desc_ptr      out parameter receiving the new descriptor
 * @param x_packed_desc int8 output tensor descriptor
 * @param x_scale_desc  per-channel scale tensor descriptor
 * @param x_zero_desc   per-channel zero-point tensor descriptor
 * @param x_desc        input tensor descriptor
 * @return INFINI_STATUS_SUCCESS, or the validation failure status
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle, Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_packed_desc,
    infiniopTensorDescriptor_t x_scale_desc,
    infiniopTensorDescriptor_t x_zero_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto info_result = PerChannelQuantI8Info::createPerChannelQuantI8Info(
        x_packed_desc, x_scale_desc, x_zero_desc, x_desc);
    CHECK_RESULT(info_result);

    auto *nvidia_handle = reinterpret_cast<device::nvidia::Handle *>(handle);
    auto *opaque = new Opaque{nvidia_handle->internal()};
    // Workspace size is 0: this operator needs no scratch memory.
    *desc_ptr = new Descriptor(opaque, info_result.take(), 0,
                               handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
/**
 * Dispatches the per-channel int8 quantization kernels on the given stream.
 *
 * Two launch strategies, chosen by row length K:
 *  - K >= 1024: one thread block of BLOCK_SIZE threads per row (grid = M).
 *  - K <  1024: 32x32 thread blocks, each covering 32 rows
 *    (grid.x = ceil(M / 32)); presumably one warp per row -- see kernel.cuh.
 * In each regime, x_zero == nullptr selects the symmetric (no zero point)
 * kernel; otherwise the asymmetric kernel receives x_zero as well.
 *
 * @return INFINI_STATUS_SUCCESS (launch errors are not checked here)
 */
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t per_channel_quant_int8Kernel(const PerChannelQuantI8Info &info, int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, cudaStream_t stream) {
    const int rows = static_cast<int>(info.M);
    const int cols = static_cast<int>(info.K);
    const bool symmetric = (x_zero == nullptr);

    if (cols >= 1024) {
        // Long rows: a full block per row.
        if (symmetric) {
            blockPerChannelQuantI8Sym<Tdata, BLOCK_SIZE>
                <<<rows, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x, rows, cols);
        } else {
            blockPerChannelQuantI8<Tdata, BLOCK_SIZE>
                <<<rows, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x_zero, x, rows, cols);
        }
        return INFINI_STATUS_SUCCESS;
    }

    // Short rows: 32-thread x-groups, 32 rows handled per block.
    constexpr unsigned int GROUP_WIDTH = 32;
    constexpr unsigned int ROWS_PER_BLOCK = 32;
    const dim3 block_dim(GROUP_WIDTH, ROWS_PER_BLOCK, 1);
    const dim3 grid_dim((rows + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK, 1, 1);
    if (symmetric) {
        warpPerChannelQuantI8Sym<Tdata, GROUP_WIDTH, ROWS_PER_BLOCK>
            <<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x, rows, cols);
    } else {
        warpPerChannelQuantI8<Tdata, GROUP_WIDTH, ROWS_PER_BLOCK>
            <<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x_zero, x, rows, cols);
    }
    return INFINI_STATUS_SUCCESS;
}
/**
 * Runs per-channel int8 quantization of x into x_packed / x_scale / x_zero.
 *
 * Picks the template BLOCK_SIZE from the device's maxThreadsPerBlock()
 * (1024 / 512 / 4096 are supported) and the element type from _info.dtype
 * (F16, F32, BF16). Passing x_zero == nullptr requests the symmetric
 * kernels (see per_channel_quant_int8Kernel).
 *
 * @param workspace       unused -- this operator needs no scratch memory
 * @param workspace_size  unused
 * @param stream_         cudaStream_t for the kernel launch
 * @return INFINI_STATUS_SUCCESS, INFINI_STATUS_BAD_TENSOR_DTYPE, or
 *         INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
 */
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *x_packed, void *x_scale, void *x_zero, const void *x,
                                     void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;

// Function-local helpers; #undef'd below so they do not leak into the rest
// of the translation unit (QUANT in particular is a name reused elsewhere
// in this project).
#define QUANT(BLOCK_SIZE, TDATA) \
    per_channel_quant_int8Kernel<BLOCK_SIZE, TDATA>(_info, (int8_t *)x_packed, (float *)x_scale, (float *)x_zero, (const TDATA *)x, stream)
#define QUANT_WITH_BLOCK_SIZE(BLOCK_SIZE)            \
    {                                                \
        if (_info.dtype == INFINI_DTYPE_F16)         \
            return QUANT(BLOCK_SIZE, half);          \
        else if (_info.dtype == INFINI_DTYPE_F32)    \
            return QUANT(BLOCK_SIZE, float);         \
        else if (_info.dtype == INFINI_DTYPE_BF16)   \
            return QUANT(BLOCK_SIZE, __nv_bfloat16); \
        else                                         \
            return INFINI_STATUS_BAD_TENSOR_DTYPE;   \
    }

    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
        QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }

#undef QUANT_WITH_BLOCK_SIZE
#undef QUANT

    return INFINI_STATUS_SUCCESS;
}
} // namespace op::per_channel_quant_int8::nvidia
// NVIDIA backend API header for the per-channel int8 quantization operator:
// instantiates the Descriptor class declaration (see the DESCRIPTOR macro in
// per_channel_quant_int8.h) in namespace op::per_channel_quant_int8::nvidia.
#ifndef __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#define __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#include "../per_channel_quant_int8.h"
DESCRIPTOR(nvidia)
#endif // __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#include "../../../operator.h"
#include "../../../handle.h"
#include "infiniop/ops/quant/per_channel_quant_int8.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/per_channel_quant_int8_nvidia.cuh"
#endif
// C API: creates a per-channel int8 quantization descriptor, dispatching on
// handle->device to the backend compiled in via the ENABLE_*_API macros
// (QY devices reuse the nvidia implementation).
__C infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
// Expands to a switch case forwarding all arguments to the backend's
// Descriptor::create; #undef'd after the switch.
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::per_channel_quant_int8::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor **>(desc_ptr), \
x_packed_desc, \
x_scale_desc, \
x_zero_desc, \
x_desc);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
// C API: writes the operator's minimum workspace size for this descriptor
// into *size, dispatching on desc->device_type.
__C infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQuantI8Descriptor_t desc, size_t *size) {
switch (desc->device_type) {
// Expands to a switch case querying the backend's minWorkspaceSize();
// #undef'd after the switch.
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc)->minWorkspaceSize(); \
return INFINI_STATUS_SUCCESS;
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
}
// C API: executes per-channel int8 quantization of x into x_packed /
// x_scale / x_zero on the given stream, dispatching on desc->device_type.
// x_zero may be nullptr to request symmetric quantization (backend choice).
__C infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor_t desc,
void *workspace,
size_t workspace_size,
void *x_packed,
void *x_scale,
void *x_zero,
const void *x,
void *stream) {
// Expands to a switch case forwarding to the backend's calculate();
// #undef'd after the switch.
#define QUANT(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, x_packed, x_scale, x_zero, x, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
QUANT(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
QUANT(INFINI_DEVICE_QY, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef QUANT
}
// C API: destroys a per-channel int8 quantization descriptor by deleting the
// backend object selected via desc->device_type.
__C infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannelQuantI8Descriptor_t desc) {
// Expands to a switch case deleting the backend descriptor; #undef'd after
// the switch.
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
}
// Shared header for the per-channel int8 quantization operator: declares the
// per-backend Descriptor class via the DESCRIPTOR macro.
#ifndef __QUANT_H__
#define __QUANT_H__
#include "../../../operator.h"
#include "info.h"
// DESCRIPTOR(NAMESPACE) declares
// op::per_channel_quant_int8::NAMESPACE::Descriptor: an InfiniopDescriptor
// holding a backend-private Opaque pointer (deleted in the out-of-line
// destructor), the validated PerChannelQuantI8Info, and a workspace size.
// Each backend defines Opaque, create(), ~Descriptor(), and calculate() in
// its own source file. (Comments are kept outside the macro body: a `//`
// comment on a continuation line would swallow the trailing backslash.)
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::per_channel_quant_int8::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PerChannelQuantI8Info _info; \
size_t _workspace_size; \
\
Descriptor(Opaque *opaque, PerChannelQuantI8Info info, \
size_t workspace_size, \
infiniDevice_t device_type, int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), _info(info), _workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t minWorkspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, Descriptor **desc_ptr, \
infiniopTensorDescriptor_t x_packed_desc, \
infiniopTensorDescriptor_t x_scale_desc, \
infiniopTensorDescriptor_t x_zero_desc, \
infiniopTensorDescriptor_t x_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *x_packed, void *x_scale, void *x_zero, const void *x, void *stream) const; \
}; \
}
#endif // __QUANT_H__
...@@ -534,13 +534,13 @@ struct Algo { ...@@ -534,13 +534,13 @@ struct Algo {
if constexpr (std::is_same<Tval_, float>::value) { if constexpr (std::is_same<Tval_, float>::value) {
auto logits = reinterpret_cast<const float *>(probs); auto logits = reinterpret_cast<const float *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc); argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else if constexpr (std::is_same<Tval_, CustomFloat16>::value) { } else if constexpr (std::is_same<Tval_, CustomFloat16>::value) {
auto logits = reinterpret_cast<const half *>(probs); auto logits = reinterpret_cast<const half *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc); argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) { } else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) {
auto logits = reinterpret_cast<const bfloat16_t *>(probs); auto logits = reinterpret_cast<const bfloat16_t *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc); argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -575,10 +575,10 @@ struct Algo { ...@@ -575,10 +575,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(float); const int max_num = SRC_MAX_SIZE / sizeof(float);
if (voc >= task_num * max_num) { if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else { } else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} }
} else if constexpr (std::is_same<Tval_, CustomFloat16>::value) { } else if constexpr (std::is_same<Tval_, CustomFloat16>::value) {
...@@ -592,10 +592,10 @@ struct Algo { ...@@ -592,10 +592,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(half); const int max_num = SRC_MAX_SIZE / sizeof(half);
if (voc >= task_num * max_num) { if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else { } else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} }
} else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) { } else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) {
...@@ -609,10 +609,10 @@ struct Algo { ...@@ -609,10 +609,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(bfloat16_t); const int max_num = SRC_MAX_SIZE / sizeof(bfloat16_t);
if (voc >= task_num * max_num) { if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else { } else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} }
} else { } else {
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
#include "cpu/random_sample_cpu.h" #include "cpu/random_sample_cpu.h"
#endif #endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/random_sample_nvidia.cuh" #include "nvidia/random_sample_nvidia.cuh"
#endif #endif
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
...@@ -50,6 +50,9 @@ infiniopCreateRandomSampleDescriptor( ...@@ -50,6 +50,9 @@ infiniopCreateRandomSampleDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia); CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -101,6 +104,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize( ...@@ -101,6 +104,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia); GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia); GET(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -165,6 +171,9 @@ __C infiniStatus_t infiniopRandomSample( ...@@ -165,6 +171,9 @@ __C infiniStatus_t infiniopRandomSample(
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_HYGON_API #ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia); CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif #endif
...@@ -210,6 +219,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor( ...@@ -210,6 +219,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia); DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia); DELETE(INFINI_DEVICE_QY, nvidia);
#endif #endif
......
...@@ -267,7 +267,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -267,7 +267,7 @@ infiniStatus_t Descriptor::calculate(
dim.x = 4; // Using 4 clusters dim.x = 4; // Using 4 clusters
dim.y = 10; dim.y = 10;
dim.z = 1; dim.z = 1;
func_type = CNRT_FUNC_TYPE_UNION1; func_type = cnrtFuncTypeUnion1;
if (_opaque->use_2d_copy) { if (_opaque->use_2d_copy) {
// Use optimized 2D copy kernel // Use optimized 2D copy kernel
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#define ARRAY_TYPE_STRIDE ptrdiff_t #define ARRAY_TYPE_STRIDE ptrdiff_t
#define ARRAY_TYPE_SIZE size_t #define ARRAY_TYPE_SIZE size_t
// 与 DEFINE_KERNELS_BY_CONSTRAINT 耦合,需要同时修改
#define MAX_BLOCK_ARRAY_SIZE 5 #define MAX_BLOCK_ARRAY_SIZE 5
#define MAX_GRID_ARRAY_SIZE 5 #define MAX_GRID_ARRAY_SIZE 5
...@@ -16,7 +15,6 @@ struct ArrayStruct { ...@@ -16,7 +15,6 @@ struct ArrayStruct {
ArrayType a[ArrSize]; ArrayType a[ArrSize];
}; };
// 各个元素分别代表:[grid_idx, block_idx, grid的stride相对于block的倍数,总的len限制]
template <typename ElementType> template <typename ElementType>
struct Constraint { struct Constraint {
ElementType grid_idx; ElementType grid_idx;
...@@ -29,9 +27,8 @@ struct Constraint { ...@@ -29,9 +27,8 @@ struct Constraint {
#define IF_CONSTRAINT_1 , const ArrayStruct<1, Constraint<ARRAY_TYPE_SIZE>> constraints #define IF_CONSTRAINT_1 , const ArrayStruct<1, Constraint<ARRAY_TYPE_SIZE>> constraints
#define IF_CONSTRAINT_2 , const ArrayStruct<2, Constraint<ARRAY_TYPE_SIZE>> constraints #define IF_CONSTRAINT_2 , const ArrayStruct<2, Constraint<ARRAY_TYPE_SIZE>> constraints
// 定义宏生成内核函数
#define DEFINE_REARRANGE_KERNEL(Tmem_type, constraint_num, block_array_size, grid_array_size) \ #define DEFINE_REARRANGE_KERNEL(Tmem_type, constraint_num, block_array_size, grid_array_size) \
extern "C" __global__ void rearrange_unit_##Tmem_type##_block_##block_array_size##_grid_##grid_array_size##_constrain_##constraint_num( \ extern "C" INFINIOP_MOORE_KERNEL rearrange_unit_##Tmem_type##_block_##block_array_size##_grid_##grid_array_size##_constrain_##constraint_num( \
void *__restrict__ dst, \ void *__restrict__ dst, \
const void *__restrict__ src, \ const void *__restrict__ src, \
const size_t block_dim, \ const size_t block_dim, \
...@@ -48,15 +45,14 @@ struct Constraint { ...@@ -48,15 +45,14 @@ struct Constraint {
return; \ return; \
} \ } \
\ \
/* 声明共享内存 */ \
__shared__ ptrdiff_t shared_src_offset; \ __shared__ ptrdiff_t shared_src_offset; \
__shared__ ptrdiff_t shared_dst_offset; \ __shared__ ptrdiff_t shared_dst_offset; \
\ \
if (constraint_num > 0) { \ if (constraint_num > 0) { \
__shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \ __shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
\ \
if (threadIdx.x == 0) { /* 只让0号线程计算 */ \ if (threadIdx.x == 0) { \
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \ \
ptrdiff_t src_offset = 0; \ ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \ ptrdiff_t dst_offset = 0; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \ ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
...@@ -64,13 +60,13 @@ struct Constraint { ...@@ -64,13 +60,13 @@ struct Constraint {
size_t remaining \ size_t remaining \
= blockIdx.x; \ = blockIdx.x; \
\ \
for (ssize_t i = grid_array_size - 1; i >= 0; i--) { \ for (ptrdiff_t i = grid_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % grid_len.a[i]; \ size_t idx = remaining % grid_len.a[i]; \
remaining /= grid_len.a[i]; \ remaining /= grid_len.a[i]; \
src_offset += idx * src_grid_stride.a[i]; \ src_offset += idx * src_grid_stride.a[i]; \
dst_offset += idx * dst_grid_stride.a[i]; \ dst_offset += idx * dst_grid_stride.a[i]; \
if (constraint_num > 0) { \ if (constraint_num > 0) { \
for (ssize_t j = 0; j < constraint_num; j++) { \ for (ptrdiff_t j = 0; j < constraint_num; j++) { \
if (i == constraints.a[j].grid_idx) { \ if (i == constraints.a[j].grid_idx) { \
constraints_grid_idx_multiple[j] = idx * constraints.a[j].grid_div_block; \ constraints_grid_idx_multiple[j] = idx * constraints.a[j].grid_div_block; \
} \ } \
...@@ -78,33 +74,30 @@ struct Constraint { ...@@ -78,33 +74,30 @@ struct Constraint {
} \ } \
} \ } \
\ \
/* 将结果存入共享内存 */ \
shared_src_offset = src_offset; \ shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \ shared_dst_offset = dst_offset; \
for (ssize_t j = 0; j < constraint_num; j++) { \ for (ptrdiff_t j = 0; j < constraint_num; j++) { \
shared_constraints_grid_idx_multiple[j] = constraints_grid_idx_multiple[j]; \ shared_constraints_grid_idx_multiple[j] = constraints_grid_idx_multiple[j]; \
} \ } \
} \ } \
\ \
/* 确保所有线程都能看到共享内存中的值 */ \
__syncthreads(); \ __syncthreads(); \
\ \
/* 所有线程直接使用计算好的偏移值 */ \
ptrdiff_t src_offset = shared_src_offset; \ ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \ ptrdiff_t dst_offset = shared_dst_offset; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \ ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
for (ssize_t j = 0; j < constraint_num; j++) { \ for (ptrdiff_t j = 0; j < constraint_num; j++) { \
constraints_grid_idx_multiple[j] = shared_constraints_grid_idx_multiple[j]; \ constraints_grid_idx_multiple[j] = shared_constraints_grid_idx_multiple[j]; \
} \ } \
\ \
for (ssize_t i = block_array_size - 1; i >= 0; i--) { \ for (ptrdiff_t i = block_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % block_len.a[i]; \ size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \ remaining /= block_len.a[i]; \
/* 计算偏移量 */ \ \
src_offset += idx * src_block_stride.a[i]; \ src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \ dst_offset += idx * dst_block_stride.a[i]; \
if (constraint_num > 0) { \ if (constraint_num > 0) { \
for (ssize_t j = 0; j < constraint_num; j++) { \ for (ptrdiff_t j = 0; j < constraint_num; j++) { \
if (i == constraints.a[j].block_idx) { \ if (i == constraints.a[j].block_idx) { \
if (constraints_grid_idx_multiple[j] + idx >= constraints.a[j].total_len) { \ if (constraints_grid_idx_multiple[j] + idx >= constraints.a[j].total_len) { \
return; \ return; \
...@@ -116,7 +109,7 @@ struct Constraint { ...@@ -116,7 +109,7 @@ struct Constraint {
\ \
src_offset += remaining * src_block_stride.a[0]; \ src_offset += remaining * src_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \ dst_offset += remaining * dst_block_stride.a[0]; \
for (ssize_t j = 0; j < constraint_num; j++) { \ for (ptrdiff_t j = 0; j < constraint_num; j++) { \
if (0 == constraints.a[j].block_idx) { \ if (0 == constraints.a[j].block_idx) { \
if (constraints_grid_idx_multiple[j] + remaining >= constraints.a[j].total_len) { \ if (constraints_grid_idx_multiple[j] + remaining >= constraints.a[j].total_len) { \
return; \ return; \
...@@ -124,39 +117,35 @@ struct Constraint { ...@@ -124,39 +117,35 @@ struct Constraint {
} \ } \
} \ } \
\ \
/* 执行数据拷贝,注意offset已经是字节偏移 */ \
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \ *reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
\ \
} else { \ } else { \
if (threadIdx.x == 0) { /* 只让0号线程计算 */ \ if (threadIdx.x == 0) { \
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \ \
ptrdiff_t src_offset = 0; \ ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \ ptrdiff_t dst_offset = 0; \
size_t remaining = blockIdx.x; \ size_t remaining = blockIdx.x; \
\ \
for (ssize_t i = grid_array_size - 1; i >= 0; i--) { \ for (ptrdiff_t i = grid_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % grid_len.a[i]; \ size_t idx = remaining % grid_len.a[i]; \
remaining /= grid_len.a[i]; \ remaining /= grid_len.a[i]; \
src_offset += idx * src_grid_stride.a[i]; \ src_offset += idx * src_grid_stride.a[i]; \
dst_offset += idx * dst_grid_stride.a[i]; \ dst_offset += idx * dst_grid_stride.a[i]; \
} \ } \
\ \
/* 将结果存入共享内存 */ \
shared_src_offset = src_offset; \ shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \ shared_dst_offset = dst_offset; \
} \ } \
\ \
/* 确保所有线程都能看到共享内存中的值 */ \
__syncthreads(); \ __syncthreads(); \
\ \
/* 所有线程直接使用计算好的偏移值 */ \
ptrdiff_t src_offset = shared_src_offset; \ ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \ ptrdiff_t dst_offset = shared_dst_offset; \
\ \
for (ssize_t i = block_array_size - 1; i > 0; i--) { \ for (ptrdiff_t i = block_array_size - 1; i > 0; i--) { \
size_t idx = remaining % block_len.a[i]; \ size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \ remaining /= block_len.a[i]; \
/* 计算偏移量 */ \ \
src_offset += idx * src_block_stride.a[i]; \ src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \ dst_offset += idx * dst_block_stride.a[i]; \
} \ } \
...@@ -164,18 +153,15 @@ struct Constraint { ...@@ -164,18 +153,15 @@ struct Constraint {
src_offset += remaining * src_block_stride.a[0]; \ src_offset += remaining * src_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \ dst_offset += remaining * dst_block_stride.a[0]; \
\ \
/* 执行数据拷贝,注意offset已经是字节偏移 */ \
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \ *reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
} \ } \
} }
// 定义支持的约束条件数量组合
#define DEFINE_KERNELS_BY_CONSTRAINT(block_array_size, grid_array_size) \ #define DEFINE_KERNELS_BY_CONSTRAINT(block_array_size, grid_array_size) \
DEFINE_KERNELS_BY_TYPE(0, block_array_size, grid_array_size) \ DEFINE_KERNELS_BY_TYPE(0, block_array_size, grid_array_size) \
DEFINE_KERNELS_BY_TYPE(1, block_array_size, grid_array_size) \ DEFINE_KERNELS_BY_TYPE(1, block_array_size, grid_array_size) \
DEFINE_KERNELS_BY_TYPE(2, block_array_size, grid_array_size) DEFINE_KERNELS_BY_TYPE(2, block_array_size, grid_array_size)
// 定义支持的类型
#define DEFINE_KERNELS_BY_TYPE(constraint_num, block_array_size, grid_array_size) \ #define DEFINE_KERNELS_BY_TYPE(constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(uchar1, constraint_num, block_array_size, grid_array_size) \ DEFINE_REARRANGE_KERNEL(uchar1, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(uchar2, constraint_num, block_array_size, grid_array_size) \ DEFINE_REARRANGE_KERNEL(uchar2, constraint_num, block_array_size, grid_array_size) \
...@@ -184,8 +170,6 @@ struct Constraint { ...@@ -184,8 +170,6 @@ struct Constraint {
DEFINE_REARRANGE_KERNEL(float4, constraint_num, block_array_size, grid_array_size) \ DEFINE_REARRANGE_KERNEL(float4, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(double4, constraint_num, block_array_size, grid_array_size) DEFINE_REARRANGE_KERNEL(double4, constraint_num, block_array_size, grid_array_size)
// 与 MAX_BLOCK_ARRAY_SIZE 和 MAX_GRID_ARRAY_SIZE 耦合,需要同时修改
// 为1-5和1-5的所有组合生成内核
DEFINE_KERNELS_BY_CONSTRAINT(1, 1) DEFINE_KERNELS_BY_CONSTRAINT(1, 1)
DEFINE_KERNELS_BY_CONSTRAINT(1, 2) DEFINE_KERNELS_BY_CONSTRAINT(1, 2)
DEFINE_KERNELS_BY_CONSTRAINT(1, 3) DEFINE_KERNELS_BY_CONSTRAINT(1, 3)
...@@ -212,7 +196,6 @@ DEFINE_KERNELS_BY_CONSTRAINT(5, 3) ...@@ -212,7 +196,6 @@ DEFINE_KERNELS_BY_CONSTRAINT(5, 3)
DEFINE_KERNELS_BY_CONSTRAINT(5, 4) DEFINE_KERNELS_BY_CONSTRAINT(5, 4)
DEFINE_KERNELS_BY_CONSTRAINT(5, 5) DEFINE_KERNELS_BY_CONSTRAINT(5, 5)
// 准备参数结构体
struct RearrangeParams { struct RearrangeParams {
std::vector<ARRAY_TYPE_SIZE> block_len; std::vector<ARRAY_TYPE_SIZE> block_len;
std::vector<ARRAY_TYPE_STRIDE> src_block_stride; std::vector<ARRAY_TYPE_STRIDE> src_block_stride;
...@@ -234,25 +217,8 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) { ...@@ -234,25 +217,8 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
CHECK_OR_RETURN(grid_num <= MAX_GRID_ARRAY_SIZE && grid_num != 0, INFINI_STATUS_BAD_PARAM); CHECK_OR_RETURN(grid_num <= MAX_GRID_ARRAY_SIZE && grid_num != 0, INFINI_STATUS_BAD_PARAM);
CHECK_OR_RETURN(block_num <= MAX_BLOCK_ARRAY_SIZE && block_num != 0, INFINI_STATUS_BAD_PARAM); CHECK_OR_RETURN(block_num <= MAX_BLOCK_ARRAY_SIZE && block_num != 0, INFINI_STATUS_BAD_PARAM);
CHECK_OR_RETURN(constraint_num <= 2, INFINI_STATUS_BAD_PARAM); CHECK_OR_RETURN(constraint_num <= 2, INFINI_STATUS_BAD_PARAM);
/*
* These variables were originally part of the CUDA implementation for this kernel.
* They have been commented out because they are not currently used in the MUSA kernel logic.
*
* This change resolves "unused variable" warnings during compilation, ensuring a clean build.
* The original declarations are preserved here for for MUSA/CUDA platform alignment.
*/
// auto block_len = params.block_len.data();
// auto src_block_stride = params.src_block_stride.data();
// auto dst_block_stride = params.dst_block_stride.data();
// auto grid_len = params.grid_len.data();
// auto src_grid_stride = params.src_grid_stride.data();
// auto dst_grid_stride = params.dst_grid_stride.data();
// auto constrain = params.constraints.data();
void *kernel_func = nullptr; void *kernel_func = nullptr;
#define GET_REARRANGE_KERNEL(Tmem_type, block_array_size, grid_array_size, constraint_num) \ #define GET_REARRANGE_KERNEL(Tmem_type, block_array_size, grid_array_size, constraint_num) \
kernel_func = (void *)rearrange_unit_##Tmem_type##_block_##block_array_size##_grid_##grid_array_size##_constrain_##constraint_num; kernel_func = (void *)rearrange_unit_##Tmem_type##_block_##block_array_size##_grid_##grid_array_size##_constrain_##constraint_num;
......
...@@ -28,7 +28,7 @@ infiniStatus_t Descriptor::create( ...@@ -28,7 +28,7 @@ infiniStatus_t Descriptor::create(
CHECK_OR_RETURN(x_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); CHECK_OR_RETURN(x_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(x_desc->ndim() == ndim, INFINI_STATUS_BAD_TENSOR_SHAPE); CHECK_OR_RETURN(x_desc->ndim() == ndim, INFINI_STATUS_BAD_TENSOR_SHAPE);
// 保存临时vector对象
auto x_shape = x_desc->shape(); auto x_shape = x_desc->shape();
auto y_shape = y_desc->shape(); auto y_shape = y_desc->shape();
auto y_strides = y_desc->strides(); auto y_strides = y_desc->strides();
...@@ -52,14 +52,12 @@ infiniStatus_t Descriptor::create( ...@@ -52,14 +52,12 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// 维度信息结构
struct Dim { struct Dim {
size_t len; size_t len;
ARRAY_TYPE_STRIDE src_stride; ARRAY_TYPE_STRIDE src_stride;
ARRAY_TYPE_STRIDE dst_stride; ARRAY_TYPE_STRIDE dst_stride;
}; };
// 分割维度结构
struct SplitDim { struct SplitDim {
size_t choose_idx; size_t choose_idx;
size_t num_per_block; size_t num_per_block;
...@@ -69,28 +67,17 @@ struct SplitDim { ...@@ -69,28 +67,17 @@ struct SplitDim {
size_t dim_len; size_t dim_len;
}; };
/**
* 根据给定的元数据准备张量重排参数,该函数主要完成以下工作:
* 1. 根据原始元数据调整单元大小,获取更适合GPU处理的单元大小
* 2. 将维度分配为块(block)维度和网格(grid)维度:
* 该步骤是核心,目标是为每个block分配尽可能多的相对连续的数据进行处理,
* 对无法完整放入块的维度进行分割,并记录分割维度信息,用于防止kernel访问越界,最大化内存访问局部性和计算效率
*/
utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta &original_meta, int max_threads) { utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta &original_meta, int max_threads) {
RearrangeParams params; RearrangeParams params;
// 获取更适合GPU处理的单元大小,这里使用2的幂次方
auto meta_result = original_meta.distributeUnit({32, 16, 8, 4, 2, 1}); auto meta_result = original_meta.distributeUnit({32, 16, 8, 4, 2, 1});
CHECK_RESULT(meta_result); CHECK_RESULT(meta_result);
const utils::RearrangeMeta &meta = meta_result.take(); const utils::RearrangeMeta &meta = meta_result.take();
// 获取维度信息
const size_t ndim = meta.ndim(); const size_t ndim = meta.ndim();
const size_t unit = meta.unit(); const size_t unit = meta.unit();
// 特殊情况:无维度,只需要简单复制
if (ndim == 0) { if (ndim == 0) {
params.block_dim = 0; params.block_dim = 0;
params.block_len_total = 1; params.block_len_total = 1;
...@@ -104,12 +91,10 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta ...@@ -104,12 +91,10 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
return utils::Result<RearrangeParams>(params); return utils::Result<RearrangeParams>(params);
} }
// 从元数据中提取必要的信息
const ptrdiff_t *idx_strides = meta.idx_strides(); const ptrdiff_t *idx_strides = meta.idx_strides();
const ptrdiff_t *dst_strides = meta.dst_strides(); const ptrdiff_t *dst_strides = meta.dst_strides();
const ptrdiff_t *src_strides = meta.src_strides(); const ptrdiff_t *src_strides = meta.src_strides();
// 准备维度信息
std::vector<Dim> dims; std::vector<Dim> dims;
std::vector<size_t> shape; std::vector<size_t> shape;
dims.reserve(ndim); dims.reserve(ndim);
...@@ -123,153 +108,93 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta ...@@ -123,153 +108,93 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
prev_idx_stride = idx_strides[i]; prev_idx_stride = idx_strides[i];
} }
// 计算src_strides的降序排序索引,类似于Rust版本中的src_strides_desc_idx std::vector<bool> block_dim_choose(ndim, false);
std::vector<size_t> src_strides_desc_idx(ndim); std::vector<SplitDim> split_dims;
std::vector<size_t> dim_order(ndim);
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0; i < ndim; ++i) {
src_strides_desc_idx[i] = i; dim_order[i] = i;
} }
std::sort(src_strides_desc_idx.begin(), src_strides_desc_idx.end(),
std::sort(dim_order.begin(), dim_order.end(),
[&dims](size_t a, size_t b) { [&dims](size_t a, size_t b) {
return std::abs(dims[a].src_stride) > std::abs(dims[b].src_stride); return std::abs(dims[a].src_stride) < std::abs(dims[b].src_stride);
}); });
// 根据最大线程数选择block和grid维度 constexpr size_t MAX_BLOCK_DIM = MAX_BLOCK_ARRAY_SIZE;
const size_t block_size = max_threads;
std::vector<bool> block_dim_choose(ndim, false);
// 初始化计数器
size_t block_elements = 1; size_t block_elements = 1;
size_t block_src_elements = 1; size_t chosen_block_dims = 0;
size_t block_dst_elements = 1;
size_t src_choose_idx = ndim;
size_t dst_choose_idx = ndim;
// 用于存储分割维度信息
std::vector<SplitDim> split_dims;
// 维度选择循环 for (size_t i = 0; i < ndim; ++i) {
while (src_choose_idx > 0 && dst_choose_idx > 0) { size_t dim_idx = dim_order[i];
// 获取当前需要处理的维度索引 size_t dim_len = shape[dim_idx];
size_t src_idx = src_strides_desc_idx[src_choose_idx - 1];
size_t dst_idx = dst_choose_idx - 1; if (chosen_block_dims < MAX_BLOCK_DIM &&
block_elements * dim_len <= (size_t)max_threads) {
if (src_idx == dst_idx) {
// 源和目标维度相同,可以一起处理 block_dim_choose[dim_idx] = true;
size_t idx = src_idx; block_elements *= dim_len;
size_t len = shape[idx]; chosen_block_dims++;
continue;
// 检查是否可以将此维度完全添加到block中 }
if (block_elements * len <= block_size) {
// 选择此维度 if (block_elements > 1 && dim_len > 1) {
block_dim_choose[idx] = true;
block_elements *= len; if (chosen_block_dims + 1 > MAX_BLOCK_DIM) {
block_src_elements *= len; break;
block_dst_elements *= len;
src_choose_idx--;
dst_choose_idx--;
} else {
// 需要分割此维度
size_t num_per_block = block_size / block_elements;
// 确保num_per_block > 0且len >= num_per_block
if (num_per_block > 0 && len >= num_per_block && num_per_block > 1) {
size_t num_per_grid = (len + num_per_block - 1) / num_per_block; // 向上取整
SplitDim split_dim = {
idx, // choose_idx
num_per_block, // num_per_block
num_per_grid, // num_per_grid
0, // array_struct_idx_block (待更新)
0, // array_struct_idx_grid (待更新)
len // 原始维度长度
};
split_dims.push_back(split_dim);
}
break;
} }
} else {
// 源和目标维度不同,需要分别处理 size_t num_per_block =
// 计算块比例 std::min(dim_len, (size_t)max_threads / block_elements);
double src_div_dst = static_cast<double>(block_src_elements) / block_dst_elements;
double src_num_per_block = std::sqrt(block_size / (double)block_elements / src_div_dst); if (num_per_block > 0) {
double dst_num_per_block = src_num_per_block * src_div_dst; size_t num_per_grid =
(dim_len + num_per_block - 1) / num_per_block;
size_t src_current_dim_len = shape[src_idx];
size_t dst_current_dim_len = shape[dst_idx]; split_dims.push_back({
dim_idx,
if (static_cast<double>(src_current_dim_len) < src_num_per_block) { num_per_block,
// 源维度可以完全添加到block num_per_grid,
block_dim_choose[src_idx] = true; 0,
block_elements *= src_current_dim_len; 0,
block_src_elements *= src_current_dim_len; dim_len
src_choose_idx--; });
} else if (static_cast<double>(dst_current_dim_len) < dst_num_per_block) {
// 目标维度可以完全添加到block block_elements *= num_per_block;
block_dim_choose[dst_idx] = true; chosen_block_dims++;
block_elements *= dst_current_dim_len; }
block_dst_elements *= dst_current_dim_len; break;
dst_choose_idx--; }
} else { }
// 需要分割源和目标维度
size_t src_num_per_block_int = static_cast<size_t>(std::floor(src_num_per_block));
size_t dst_num_per_block_int = static_cast<size_t>(std::floor(dst_num_per_block));
// 计算网格尺寸
size_t src_num_per_grid = (src_current_dim_len + src_num_per_block_int - 1) / src_num_per_block_int; // 向上取整
size_t dst_num_per_grid = (dst_current_dim_len + dst_num_per_block_int - 1) / dst_num_per_block_int; // 向上取整
// 处理源维度
if (src_num_per_block_int > 1) {
if (src_num_per_grid == 1) {
// 可以完全放入块
block_dim_choose[src_idx] = true;
block_elements *= src_current_dim_len;
block_src_elements *= src_current_dim_len;
src_choose_idx--;
} else {
// 需要分割
SplitDim split_dim = {
src_idx, // choose_idx
src_num_per_block_int, // num_per_block
src_num_per_grid, // num_per_grid
0, // array_struct_idx_block (待更新)
0, // array_struct_idx_grid (待更新)
src_current_dim_len // 原始维度长度
};
split_dims.push_back(split_dim);
}
}
// 处理目标维度
if (dst_num_per_block_int > 1) {
if (dst_num_per_grid == 1) {
// 可以完全放入块
block_dim_choose[dst_idx] = true;
block_elements *= dst_current_dim_len;
block_dst_elements *= dst_current_dim_len;
dst_choose_idx--;
} else {
// 需要分割
SplitDim split_dim = {
dst_idx, // choose_idx
dst_num_per_block_int, // num_per_block
dst_num_per_grid, // num_per_grid
0, // array_struct_idx_block (待更新)
0, // array_struct_idx_grid (待更新)
dst_current_dim_len // 原始维度长度
};
split_dims.push_back(split_dim);
}
}
break; if (block_elements == 1 && ndim > 0) {
} size_t dim_idx = dim_order[0];
size_t dim_len = shape[dim_idx];
if (dim_len <= (size_t)max_threads) {
block_dim_choose[dim_idx] = true;
block_elements = dim_len;
} else {
size_t num_per_block = std::min(dim_len, (size_t)max_threads);
size_t num_per_grid = (dim_len + num_per_block - 1) / num_per_block;
SplitDim split_dim = {
dim_idx,
num_per_block,
num_per_grid,
0,
0,
dim_len};
split_dims.push_back(split_dim);
block_elements = num_per_block;
} }
} }
// 准备block维度相关参数
size_t block_dim = 0; size_t block_dim = 0;
size_t block_len_total = 1; size_t block_len_total = block_elements;
std::vector<ARRAY_TYPE_SIZE> block_len; std::vector<ARRAY_TYPE_SIZE> block_len;
std::vector<ARRAY_TYPE_STRIDE> src_block_stride; std::vector<ARRAY_TYPE_STRIDE> src_block_stride;
...@@ -279,46 +204,40 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta ...@@ -279,46 +204,40 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
std::vector<ARRAY_TYPE_STRIDE> src_grid_stride; std::vector<ARRAY_TYPE_STRIDE> src_grid_stride;
std::vector<ARRAY_TYPE_STRIDE> dst_grid_stride; std::vector<ARRAY_TYPE_STRIDE> dst_grid_stride;
// 处理block维度,填充block_len和block_stride
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0; i < ndim; ++i) {
if (block_dim_choose[i]) { if (block_dim_choose[i]) {
block_len.push_back(shape[i]); block_len.push_back(shape[i]);
src_block_stride.push_back(dims[i].src_stride); src_block_stride.push_back(dims[i].src_stride);
dst_block_stride.push_back(dims[i].dst_stride); dst_block_stride.push_back(dims[i].dst_stride);
block_dim += 1; block_dim += 1;
block_len_total *= shape[i];
} }
// 处理分割维度的block部分
for (size_t j = 0; j < split_dims.size(); ++j) { for (size_t j = 0; j < split_dims.size(); ++j) {
if (i == split_dims[j].choose_idx) { if (i == split_dims[j].choose_idx) {
block_len.push_back(split_dims[j].num_per_block); block_len.push_back(split_dims[j].num_per_block);
src_block_stride.push_back(dims[i].src_stride); src_block_stride.push_back(dims[i].src_stride);
dst_block_stride.push_back(dims[i].dst_stride); dst_block_stride.push_back(dims[i].dst_stride);
split_dims[j].array_struct_idx_block = block_dim; split_dims[j].array_struct_idx_block = static_cast<int>(block_dim);
block_dim += 1; block_dim += 1;
block_len_total *= split_dims[j].num_per_block;
} }
} }
} }
// 处理grid维度,填充grid_len和grid_stride
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0; i < ndim; ++i) {
if (!block_dim_choose[i]) { if (!block_dim_choose[i]) {
bool is_split = false; bool is_split = false;
// 检查是否是分割维度
for (size_t j = 0; j < split_dims.size(); ++j) { for (size_t j = 0; j < split_dims.size(); ++j) {
if (i == split_dims[j].choose_idx) { if (i == split_dims[j].choose_idx) {
is_split = true; is_split = true;
grid_len.push_back(split_dims[j].num_per_grid); grid_len.push_back(split_dims[j].num_per_grid);
src_grid_stride.push_back(dims[i].src_stride * split_dims[j].num_per_block); src_grid_stride.push_back(dims[i].src_stride * split_dims[j].num_per_block);
dst_grid_stride.push_back(dims[i].dst_stride * split_dims[j].num_per_block); dst_grid_stride.push_back(dims[i].dst_stride * split_dims[j].num_per_block);
split_dims[j].array_struct_idx_grid = grid_len.size() - 1; split_dims[j].array_struct_idx_grid = static_cast<int>(grid_len.size() - 1);
break;
} }
} }
// 如果不是分割维度,则作为完整的grid维度
if (!is_split) { if (!is_split) {
grid_len.push_back(shape[i]); grid_len.push_back(shape[i]);
src_grid_stride.push_back(dims[i].src_stride); src_grid_stride.push_back(dims[i].src_stride);
...@@ -327,17 +246,14 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta ...@@ -327,17 +246,14 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
} }
} }
// 如果grid_len为空,添加一个默认值
if (grid_len.empty()) { if (grid_len.empty()) {
grid_len.push_back(1); grid_len.push_back(1);
src_grid_stride.push_back(0); src_grid_stride.push_back(0);
dst_grid_stride.push_back(0); dst_grid_stride.push_back(0);
} }
// 处理约束条件 - 使用与Rust版本相似的逻辑
std::vector<Constraint<ARRAY_TYPE_SIZE>> constraints; std::vector<Constraint<ARRAY_TYPE_SIZE>> constraints;
// 限制最多处理2个约束条件
for (size_t i = 0; i < split_dims.size(); ++i) { for (size_t i = 0; i < split_dims.size(); ++i) {
if (split_dims[i].dim_len % split_dims[i].num_per_block == 0) { if (split_dims[i].dim_len % split_dims[i].num_per_block == 0) {
continue; continue;
...@@ -348,9 +264,12 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta ...@@ -348,9 +264,12 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
constraint.grid_div_block = split_dims[i].num_per_block; constraint.grid_div_block = split_dims[i].num_per_block;
constraint.total_len = split_dims[i].dim_len; constraint.total_len = split_dims[i].dim_len;
constraints.push_back(constraint); constraints.push_back(constraint);
if (constraints.size() >= 2) {
break;
}
} }
// 设置参数
params.block_dim = block_dim; params.block_dim = block_dim;
params.block_len_total = block_len_total; params.block_len_total = block_len_total;
params.block_len = block_len; params.block_len = block_len;
...@@ -365,7 +284,6 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta ...@@ -365,7 +284,6 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
return utils::Result<RearrangeParams>(params); return utils::Result<RearrangeParams>(params);
} }
// 带约束的内核启动模板函数
template <unsigned int BLOCK_SIZE> template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel( infiniStatus_t launchKernel(
void *y, void *y,
...@@ -375,30 +293,28 @@ infiniStatus_t launchKernel( ...@@ -375,30 +293,28 @@ infiniStatus_t launchKernel(
size_t unit_size, size_t unit_size,
musaStream_t stream) { musaStream_t stream) {
// 获取内核函数 RearrangeParams params_copy = params;
RearrangeParams params_copy = params; // 创建一个非const副本
auto kernel_func_result = getRearrangeKernel(params_copy); auto kernel_func_result = getRearrangeKernel(params_copy);
CHECK_RESULT(kernel_func_result); CHECK_RESULT(kernel_func_result);
auto kernel_func = kernel_func_result.take(); auto kernel_func = kernel_func_result.take();
// 创建非const的临时变量
size_t block_dim = params.block_dim; size_t block_dim = params.block_dim;
size_t block_len_total = params.block_len_total; size_t block_len_total = params.block_len_total;
// 计算对齐后的线程块大小(Block Size)以适配 MUSA 架构的Warp特性 // Calculate aligned thread block size to match MUSA architecture's Warp characteristics:
// - MUSA 架构以 32 线程为基本调度单位(1个 Warp // - MUSA architecture uses 32 threads as the fundamental scheduling unit (1 Warp).
// - 通过向上取整到最近的 32 的倍数,确保线程块包含完整的 Warp // - Round up to the nearest multiple of 32 to ensure the block consists of full Warps.
// - MUSA 似乎不支持非 32 整数倍的计算 // - MUSA hardware/scheduler typically requires thread counts to be multiples of 32.
size_t aligned_block_size = ((block_len_total + 31) / 32) * 32; size_t aligned_block_size = ((block_len_total + 31) / 32) * 32;
block_len_total = aligned_block_size; // block_len_total = aligned_block_size;
// 确保对齐后的线程块大小不超过硬件/模板限制 // Ensure the aligned block size does not exceed hardware or template-defined limits.
if (aligned_block_size > BLOCK_SIZE) { if (aligned_block_size > BLOCK_SIZE) {
aligned_block_size = BLOCK_SIZE; // 降级到安全值 aligned_block_size = BLOCK_SIZE;
} }
// 检查向量尺寸是否合理 // Validate that vector dimensions are sufficient for the specified block dimension.
if (params.block_len.size() < block_dim || params.src_block_stride.size() < block_dim || params.dst_block_stride.size() < block_dim) { if (params.block_len.size() < block_dim || params.src_block_stride.size() < block_dim || params.dst_block_stride.size() < block_dim) {
return INFINI_STATUS_BAD_PARAM; return INFINI_STATUS_BAD_PARAM;
} }
...@@ -428,18 +344,18 @@ infiniStatus_t launchKernel( ...@@ -428,18 +344,18 @@ infiniStatus_t launchKernel(
const_cast<void *>(static_cast<const void *>(params.dst_grid_stride.data())), const_cast<void *>(static_cast<const void *>(params.dst_grid_stride.data())),
const_cast<void *>(static_cast<const void *>(constraints_data))}; const_cast<void *>(static_cast<const void *>(constraints_data))};
// musaLaunchKernel 的 blockDim 似乎必须满足: // The blockDim for musaLaunchKernel must satisfy the following constraints:
// - 是32的整数倍(适配 MUSA Warp 调度机制) // - Must be a multiple of 32 (aligned with MUSA's Warp scheduling mechanism).
// - 不小于实际需要处理的元素数(block_len_total // - Must be greater than or equal to the number of elements to process (block_len_total).
// - 向上取整,数学等效:ceil(n / 32) * 32 // - Math equivalent: ceil(n / 32) * 32 (rounding up to the nearest warp).
CHECK_OR_RETURN(musaLaunchKernel( CHECK_OR_RETURN(musaLaunchKernel(
kernel_func, kernel_func,
grid_size, aligned_block_size, static_cast<unsigned int>(grid_size), static_cast<unsigned int>(aligned_block_size),
args, 0, stream) args, 0, stream)
== musaSuccess, == musaSuccess,
INFINI_STATUS_INTERNAL_ERROR); INFINI_STATUS_INTERNAL_ERROR);
// 设备同步,检查内核执行是否出错 // Synchronize the device to detect potential asynchronous kernel execution errors.
musaError_t err = musaDeviceSynchronize(); musaError_t err = musaDeviceSynchronize();
if (err != musaSuccess) { if (err != musaSuccess) {
std::cerr << "[ERROR] musaDeviceSynchronize failed: " << err << std::endl; std::cerr << "[ERROR] musaDeviceSynchronize failed: " << err << std::endl;
...@@ -456,38 +372,27 @@ infiniStatus_t Descriptor::calculate( ...@@ -456,38 +372,27 @@ infiniStatus_t Descriptor::calculate(
auto musa_stream = reinterpret_cast<musaStream_t>(stream); auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// 如果没有维度,直接进行内存拷贝
if (_meta.ndim() == 0) { if (_meta.ndim() == 0) {
auto err = musaMemcpyAsync(y, x, _meta.unit(), musaMemcpyDeviceToDevice, musa_stream);
if (err != musaSuccess) {
return INFINI_STATUS_INTERNAL_ERROR;
}
CHECK_OR_RETURN(musaMemcpyAsync(y, x, _meta.unit(), musaMemcpyDeviceToDevice, musa_stream) == musaSuccess, CHECK_OR_RETURN(musaMemcpyAsync(y, x, _meta.unit(), musaMemcpyDeviceToDevice, musa_stream) == musaSuccess,
INFINI_STATUS_INTERNAL_ERROR); INFINI_STATUS_INTERNAL_ERROR);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// 获取设备属性
int max_threads = _opaque->internal->maxThreadsPerBlock(); int max_threads = _opaque->internal->maxThreadsPerBlock();
// 准备参数
auto params_result = prepareRearrangeParams(_meta, std::min(MOORE_BLOCK_SIZE_1024, max_threads)); auto params_result = prepareRearrangeParams(_meta, std::min(MOORE_BLOCK_SIZE_1024, max_threads));
CHECK_RESULT(params_result); CHECK_RESULT(params_result);
auto params = params_result.take(); auto params = params_result.take();
// 计算grid大小
size_t grid_size = 1; size_t grid_size = 1;
for (size_t i = 0; i < params.grid_len.size(); ++i) { for (size_t i = 0; i < params.grid_len.size(); ++i) {
grid_size *= params.grid_len[i]; grid_size *= params.grid_len[i];
} }
// 检查grid大小是否为0
if (grid_size == 0) { if (grid_size == 0) {
return INFINI_STATUS_BAD_PARAM; return INFINI_STATUS_BAD_PARAM;
} }
// 根据设备属性选择合适的内核
infiniStatus_t status = INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; infiniStatus_t status = INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
size_t block_size = params.block_len_total; size_t block_size = params.block_len_total;
...@@ -497,7 +402,6 @@ infiniStatus_t Descriptor::calculate( ...@@ -497,7 +402,6 @@ infiniStatus_t Descriptor::calculate(
} else if (block_size <= MOORE_BLOCK_SIZE_1024) { } else if (block_size <= MOORE_BLOCK_SIZE_1024) {
status = launchKernel<MOORE_BLOCK_SIZE_1024>(y, x, grid_size, params, _meta.unit(), musa_stream); status = launchKernel<MOORE_BLOCK_SIZE_1024>(y, x, grid_size, params, _meta.unit(), musa_stream);
} else { } else {
std::cerr << "[ERROR] block_size=" << block_size << " exceeds max supported" << std::endl;
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
......
...@@ -345,12 +345,10 @@ infiniStatus_t launchKernel( ...@@ -345,12 +345,10 @@ infiniStatus_t launchKernel(
const_cast<void *>(static_cast<const void *>(params.dst_grid_stride.data())), const_cast<void *>(static_cast<const void *>(params.dst_grid_stride.data())),
const_cast<void *>(static_cast<const void *>(constraints_data))}; const_cast<void *>(static_cast<const void *>(constraints_data))};
CHECK_OR_RETURN(cudaLaunchKernel( CHECK_CUDA(cudaLaunchKernel(
kernel_func, kernel_func,
static_cast<unsigned int>(grid_size), static_cast<unsigned int>(BLOCK_SIZE), static_cast<unsigned int>(grid_size), static_cast<unsigned int>(BLOCK_SIZE),
args, 0, stream) args, 0, stream));
== cudaSuccess,
INFINI_STATUS_INTERNAL_ERROR);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
#include "ascend/rearrange_ascend.h" #include "ascend/rearrange_ascend.h"
#endif #endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/rearrange_nvidia.cuh" #include "nvidia/rearrange_nvidia.cuh"
#endif #endif
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
...@@ -52,6 +52,9 @@ __C infiniStatus_t infiniopCreateRearrangeDescriptor( ...@@ -52,6 +52,9 @@ __C infiniStatus_t infiniopCreateRearrangeDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia); CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -102,6 +105,9 @@ __C infiniStatus_t infiniopRearrange( ...@@ -102,6 +105,9 @@ __C infiniStatus_t infiniopRearrange(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -150,6 +156,9 @@ __C infiniStatus_t infiniopDestroyRearrangeDescriptor( ...@@ -150,6 +156,9 @@ __C infiniStatus_t infiniopDestroyRearrangeDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia); DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia); DELETE(INFINI_DEVICE_QY, nvidia);
#endif #endif
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "../../../../../build/ninetoothed/relu.h" #include "../../../../../build/ninetoothed/relu.h"
#include "../../../devices/metax/metax_common.h" #include "../../../devices/metax/metax_common.h"
#include "../../../ninetoothed/utils.h"
#include "relu_metax.h" #include "relu_metax.h"
namespace op::relu::metax { namespace op::relu::metax {
...@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate( ...@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate(
} }
const auto &ndim{_info.getNdim()}; const auto &ndim{_info.getNdim()};
const auto &x_shape_{_info.getInputShape(0)};
const auto &x_strides_{_info.getInputStrides(0)}; auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim}; auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
auto x_data{const_cast<void *>(inputs[0])};
auto x_shape{x_shape_vec.data()};
auto x_strides{x_strides_vec.data()};
const NineToothedTensor x{x_data, x_shape, x_strides};
const auto &y_shape_{_info.getOutputShape()};
const auto &y_strides_{_info.getOutputStrides()};
std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
auto y_data{output};
auto y_shape{y_shape_vec.data()};
auto y_strides{y_strides_vec.data()};
const NineToothedTensor y{y_data, y_shape, y_strides};
constexpr auto block_size{1024}; constexpr auto block_size{1024};
switch (_dtype) { switch (_dtype) {
......
#ifdef ENABLE_NINETOOTHED #ifdef ENABLE_NINETOOTHED
#include "../../../../../build/ninetoothed/relu.h" #include "../../../../../build/ninetoothed/relu.h"
#include "../../../ninetoothed/utils.h"
#endif #endif
#include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" #include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
...@@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate( ...@@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate(
} }
#ifdef ENABLE_NINETOOTHED #ifdef ENABLE_NINETOOTHED
const auto &ndim{_info.getNdim()}; const auto &ndim{_info.getNdim()};
const auto &x_shape_{_info.getInputShape(0)};
const auto &x_strides_{_info.getInputStrides(0)}; auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim}; auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
auto x_data{const_cast<void *>(inputs[0])};
auto x_shape{x_shape_vec.data()};
auto x_strides{x_strides_vec.data()};
const NineToothedTensor x{x_data, x_shape, x_strides};
const auto &y_shape_{_info.getOutputShape()};
const auto &y_strides_{_info.getOutputStrides()};
std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
auto y_data{output};
auto y_shape{y_shape_vec.data()};
auto y_strides{y_strides_vec.data()};
const NineToothedTensor y{y_data, y_shape, y_strides};
constexpr auto block_size{1024}; constexpr auto block_size{1024};
switch (_dtype) { switch (_dtype) {
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
#include "cpu/relu_cpu.h" #include "cpu/relu_cpu.h"
#endif #endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/relu_nvidia.cuh" #include "nvidia/relu_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
...@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateReluDescriptor( ...@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateReluDescriptor(
#ifdef ENABLE_NINETOOTHED #ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_METAX, metax); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default: default:
...@@ -80,6 +83,10 @@ __C infiniStatus_t infiniopGetReluWorkspaceSize(infiniopReluDescriptor_t desc, s ...@@ -80,6 +83,10 @@ __C infiniStatus_t infiniopGetReluWorkspaceSize(infiniopReluDescriptor_t desc, s
GET(INFINI_DEVICE_METAX, metax) GET(INFINI_DEVICE_METAX, metax)
#endif #endif
#endif #endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -119,6 +126,9 @@ __C infiniStatus_t infiniopRelu( ...@@ -119,6 +126,9 @@ __C infiniStatus_t infiniopRelu(
#ifdef ENABLE_NINETOOTHED #ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_METAX, metax); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default: default:
...@@ -154,6 +164,9 @@ infiniopDestroyReluDescriptor(infiniopReluDescriptor_t desc) { ...@@ -154,6 +164,9 @@ infiniopDestroyReluDescriptor(infiniopReluDescriptor_t desc) {
#ifdef ENABLE_NINETOOTHED #ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_METAX, metax); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default: default:
......
...@@ -82,7 +82,7 @@ __mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight, ...@@ -82,7 +82,7 @@ __mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight,
} }
} else { } else {
// Large vector processing with chunking // Large vector processing with chunking
__bang_write_zero(reduction_buffer, reduce_buffer_size); __bang_write_value(reduction_buffer, reduce_buffer_size, 0);
size_t processed_elements = 0; size_t processed_elements = 0;
while (processed_elements < vector_size) { while (processed_elements < vector_size) {
...@@ -223,9 +223,9 @@ void rmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrt ...@@ -223,9 +223,9 @@ void rmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrt
kernel_dim.x = core_per_cluster; kernel_dim.x = core_per_cluster;
kernel_dim.y = cluster_count; kernel_dim.y = cluster_count;
kernel_dim.z = 1; kernel_dim.z = 1;
kernel_type = CNRT_FUNC_TYPE_UNION1; // Can choose others, but must adapt kernel_type accordingly kernel_type = cnrtFuncTypeUnion1; // Can choose others, but must adapt kernel_type accordingly
int dimsize = shape[ndim - 1]; // Length of operation dimension int dimsize = shape[ndim - 1]; // Length of operation dimension
int dim_s; // dim_s is the next power of 2 greater than dimsize int dim_s; // dim_s is the next power of 2 greater than dimsize
float mi = log2(dimsize); float mi = log2(dimsize);
if (floor(mi) == mi) { if (floor(mi) == mi) {
dim_s = dimsize; dim_s = dimsize;
......
...@@ -117,12 +117,14 @@ infiniStatus_t Descriptor::calculate( ...@@ -117,12 +117,14 @@ infiniStatus_t Descriptor::calculate(
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream); auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
// launch kernel with different block sizes // launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_2048>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
#include "cpu/rms_norm_cpu.h" #include "cpu/rms_norm_cpu.h"
#endif #endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/rms_norm_nvidia.cuh" #include "nvidia/rms_norm_nvidia.cuh"
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
...@@ -52,6 +52,9 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor( ...@@ -52,6 +52,9 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia); CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -97,6 +100,9 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d ...@@ -97,6 +100,9 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia); GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia); GET(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -143,6 +149,9 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works ...@@ -143,6 +149,9 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -188,6 +197,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t ...@@ -188,6 +197,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia); DESTROY(INFINI_DEVICE_QY, nvidia);
#endif #endif
......
...@@ -40,8 +40,9 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -40,8 +40,9 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table, const Tdata *sin_table,
const Tdata *cos_table, const Tdata *cos_table,
cnrtQueue_t queue) { cnrtQueue_t queue) {
auto dimx = uint32_t(info.seqlen); auto batch_size = uint32_t(info.batch);
auto dimy = uint32_t(info.nhead); auto seqlen = uint32_t(info.seqlen);
auto nhead = uint32_t(info.nhead);
auto table_dim = uint32_t(info.table_dim); auto table_dim = uint32_t(info.table_dim);
cnrtDim3_t k_dim; cnrtDim3_t k_dim;
...@@ -51,14 +52,14 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -51,14 +52,14 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
k_dim.x = 4; k_dim.x = 4;
k_dim.y = 1; k_dim.y = 1;
k_dim.z = 1; k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1; k_type = cnrtFuncTypeUnion1;
// Launch kernel // Launch kernel with batch dimension
ropeKernel<<<k_dim, k_type, queue>>>( ropeKernel<<<k_dim, k_type, queue>>>(
y, x, pos_ids, sin_table, cos_table, y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim, batch_size, seqlen, nhead, table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead, info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead,
info.algo); info.algo);
cnrtQueueSync(queue); cnrtQueueSync(queue);
......
...@@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel( ...@@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel(
const Tindex *pos_ids, const Tindex *pos_ids,
const Tdata *sin_table, const Tdata *sin_table,
const Tdata *cos_table, const Tdata *cos_table,
uint32_t batch_size,
uint32_t seqlen, uint32_t seqlen,
uint32_t nhead, uint32_t nhead,
uint32_t table_dim, uint32_t table_dim,
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen, ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead, ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen, ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead, ptrdiff_t x_stride_nhead,
infiniopRoPEAlgo_t algo) { infiniopRoPEAlgo_t algo) {
...@@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel( ...@@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel(
} }
// Task distribution // Task distribution
const int batch_volume = seqlen * nhead; const int batch_volume = batch_size * seqlen * nhead;
const int remaining_tasks = batch_volume % taskDim; const int remaining_tasks = batch_volume % taskDim;
const int base_tasks_per_core = batch_volume / taskDim; const int base_tasks_per_core = batch_volume / taskDim;
const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0); const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
...@@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel( ...@@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel(
// Main processing loop // Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) { for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
int seq_idx = i / nhead; // Calculate 3D indices from flattened task index
int batch_idx = i / (seqlen * nhead);
int seq_idx = (i % (seqlen * nhead)) / nhead;
int head_idx = i % nhead; int head_idx = i % nhead;
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead; // Calculate offsets with batch dimension
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead; // Note: For GPT-NeoX, the stride calculations might be different
int out_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
int in_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
// Get position index for this sequence
// Position IDs are shared across batches or per batch depending on input
Tindex pos_idx;
if (use_pos_ids_buffer) {
// Position IDs loaded in NRAM
pos_idx = srcP[seq_idx];
} else {
// Position IDs in global memory
// Handle both cases: position IDs shape could be [seqlen] or [batch_size, seqlen]
if (batch_size > 1) {
// Assume position IDs have shape [batch_size, seqlen]
int pos_flat_idx = batch_idx * seqlen + seq_idx;
pos_idx = pos_ids[pos_flat_idx];
} else {
// Single batch case: position IDs shape is [seqlen]
pos_idx = pos_ids[seq_idx];
}
}
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim; int rot_offset = pos_idx * table_dim;
int processed = 0; int processed = 0;
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
#include "cpu/rope_cpu.h" #include "cpu/rope_cpu.h"
#endif #endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/rope_nvidia.cuh" #include "nvidia/rope_nvidia.cuh"
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
...@@ -56,6 +56,9 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( ...@@ -56,6 +56,9 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia); CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -101,6 +104,9 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, ...@@ -101,6 +104,9 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia); GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia); GET(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -155,6 +161,9 @@ __C infiniStatus_t infiniopRoPE( ...@@ -155,6 +161,9 @@ __C infiniStatus_t infiniopRoPE(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
...@@ -201,6 +210,9 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { ...@@ -201,6 +210,9 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia); DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia); DELETE(INFINI_DEVICE_QY, nvidia);
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment