Unverified Commit 012df56c authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #963 from InfiniTensor/issue/523-020

issue/523 - Switched to the Cambricon MLU 1.22 interface
parents f1b8ab64 aac54e1f
...@@ -62,7 +62,7 @@ infiniStatus_t commInitAll( ...@@ -62,7 +62,7 @@ infiniStatus_t commInitAll(
for (int i = 0; i < ndevice; i++) { for (int i = 0; i < ndevice; i++) {
rank_list[i] = i; rank_list[i] = i;
CHECK_INTERNAL(cnrtSetDevice(device_ids[i]), CNRT_RET_SUCCESS); CHECK_INTERNAL(cnrtSetDevice(device_ids[i]), cnrtSuccess);
} }
CHECK_CNCL(cnclInitComms(cncl_comms.data(), ndevice, CHECK_CNCL(cnclInitComms(cncl_comms.data(), ndevice,
......
...@@ -127,8 +127,8 @@ private: ...@@ -127,8 +127,8 @@ private:
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size; const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// Copy input pointer array and metadata to device // Copy input pointer array and metadata to device
CNRT_CHECK(cnrtMemcpy(workspace, (void *)h_inputs_arr, input_arr_size, CNRT_MEM_TRANS_DIR_HOST2DEV)); CNRT_CHECK(cnrtMemcpy(workspace, (void *)h_inputs_arr, input_arr_size, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtMemcpy((void *)d_meta_start, (void *)info_meta_start, info.getMetaMemSize(), CNRT_MEM_TRANS_DIR_HOST2DEV)); CNRT_CHECK(cnrtMemcpy((void *)d_meta_start, (void *)info_meta_start, info.getMetaMemSize(), cnrtMemcpyHostToDev));
// Setup pointers to device memory regions // Setup pointers to device memory regions
d_inputs_arr = reinterpret_cast<const void **>(workspace); d_inputs_arr = reinterpret_cast<const void **>(workspace);
......
...@@ -248,10 +248,10 @@ void launchElementwiseKernelWrapper( ...@@ -248,10 +248,10 @@ void launchElementwiseKernelWrapper(
dim.z = 1; dim.z = 1;
// Choose kernel type based on problem characteristics // Choose kernel type based on problem characteristics
cnrtFunctionType_t func_type = CNRT_FUNC_TYPE_BLOCK; cnrtFunctionType_t func_type = cnrtFuncTypeBlock;
if (output_size > 1024 * 1024 && output_contiguous) { if (output_size > 1024 * 1024 && output_contiguous) {
// For large contiguous operations, use UNION type // For large contiguous operations, use UNION type
func_type = CNRT_FUNC_TYPE_UNION1; func_type = cnrtFuncTypeUnion1;
} }
// Launch the kernel with optimal configuration // Launch the kernel with optimal configuration
......
...@@ -131,7 +131,7 @@ void causalSoftmaxUnion(void *workspace, int core_per_cluster, int cluster_count ...@@ -131,7 +131,7 @@ void causalSoftmaxUnion(void *workspace, int core_per_cluster, int cluster_count
kernel_dim.x = core_per_cluster; kernel_dim.x = core_per_cluster;
kernel_dim.y = cluster_count; kernel_dim.y = cluster_count;
kernel_dim.z = 1; kernel_dim.z = 1;
kernel_type = CNRT_FUNC_TYPE_UNION1; kernel_type = cnrtFuncTypeUnion1;
// Launch kernel // Launch kernel
causalSoftmax<T><<<kernel_dim, kernel_type, queue>>>( causalSoftmax<T><<<kernel_dim, kernel_type, queue>>>(
......
...@@ -15,8 +15,8 @@ struct Descriptor::Opaque { ...@@ -15,8 +15,8 @@ struct Descriptor::Opaque {
cnnlDestroyTensorDescriptor(a); cnnlDestroyTensorDescriptor(a);
cnnlDestroyTensorDescriptor(b); cnnlDestroyTensorDescriptor(b);
cnnlDestroyTensorDescriptor(c); cnnlDestroyTensorDescriptor(c);
cnnlMatMulDescDestroy(op); cnnlDestroyMatMulDescriptor(op);
cnnlMatMulAlgoDestroy(algo); cnnlDestroyMatMulAlgo(algo);
cnnlDestroyMatMulHeuristicResult(algoResult); cnnlDestroyMatMulHeuristicResult(algoResult);
} }
}; };
...@@ -85,8 +85,8 @@ infiniStatus_t Descriptor::create( ...@@ -85,8 +85,8 @@ infiniStatus_t Descriptor::create(
cnnlMatMulDescriptor_t op; cnnlMatMulDescriptor_t op;
cnnlMatMulAlgo_t algo; cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult; cnnlMatMulHeuristicResult_t algoResult;
CHECK_BANG(cnnlMatMulDescCreate(&op)); CHECK_BANG(cnnlCreateMatMulDescriptor(&op));
CHECK_BANG(cnnlMatMulAlgoCreate(&algo)); CHECK_BANG(cnnlCreateMatMulAlgo(&algo));
CHECK_BANG(cnnlCreateMatMulHeuristicResult(&algoResult)); CHECK_BANG(cnnlCreateMatMulHeuristicResult(&algoResult));
int32_t use_stride = true; int32_t use_stride = true;
CHECK_BANG(cnnlSetMatMulDescAttr( CHECK_BANG(cnnlSetMatMulDescAttr(
...@@ -101,7 +101,7 @@ infiniStatus_t Descriptor::create( ...@@ -101,7 +101,7 @@ infiniStatus_t Descriptor::create(
(cnrtQueue_t) nullptr, (cnrtQueue_t) nullptr,
[&](cnnlHandle_t _handle) { [&](cnnlHandle_t _handle) {
CHECK_BANG( CHECK_BANG(
cnnlGetBatchMatMulAlgoHeuristic( cnnlGetBatchMatMulExAlgoHeuristic(
_handle, _handle,
op, a, b, c, op, a, b, c,
NULL, 1, &algoResult, &count)); NULL, 1, &algoResult, &count));
...@@ -109,7 +109,7 @@ infiniStatus_t Descriptor::create( ...@@ -109,7 +109,7 @@ infiniStatus_t Descriptor::create(
})); }));
size_t workspace_size; size_t workspace_size;
CHECK_BANG(cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size)); CHECK_BANG(cnnlGetBatchMatMulExHeuristicResult(algoResult, algo, &workspace_size));
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, info, workspace_size, dtype, info, workspace_size,
...@@ -135,7 +135,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -135,7 +135,7 @@ infiniStatus_t Descriptor::calculate(
CHECK_STATUS(_opaque->internal->useCnnl( CHECK_STATUS(_opaque->internal->useCnnl(
(cnrtQueue_t)stream, (cnrtQueue_t)stream,
[&](cnnlHandle_t handle) { [&](cnnlHandle_t handle) {
CHECK_BANG(cnnlBatchMatMulBCast_v2( CHECK_BANG(cnnlBatchMatMulEx(
handle, handle,
_opaque->op, _opaque->op,
_opaque->algo, _opaque->algo,
......
...@@ -534,13 +534,13 @@ struct Algo { ...@@ -534,13 +534,13 @@ struct Algo {
if constexpr (std::is_same<Tval_, float>::value) { if constexpr (std::is_same<Tval_, float>::value) {
auto logits = reinterpret_cast<const float *>(probs); auto logits = reinterpret_cast<const float *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc); argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else if constexpr (std::is_same<Tval_, CustomFloat16>::value) { } else if constexpr (std::is_same<Tval_, CustomFloat16>::value) {
auto logits = reinterpret_cast<const half *>(probs); auto logits = reinterpret_cast<const half *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc); argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) { } else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) {
auto logits = reinterpret_cast<const bfloat16_t *>(probs); auto logits = reinterpret_cast<const bfloat16_t *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc); argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -575,10 +575,10 @@ struct Algo { ...@@ -575,10 +575,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(float); const int max_num = SRC_MAX_SIZE / sizeof(float);
if (voc >= task_num * max_num) { if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else { } else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} }
} else if constexpr (std::is_same<Tval_, CustomFloat16>::value) { } else if constexpr (std::is_same<Tval_, CustomFloat16>::value) {
...@@ -592,10 +592,10 @@ struct Algo { ...@@ -592,10 +592,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(half); const int max_num = SRC_MAX_SIZE / sizeof(half);
if (voc >= task_num * max_num) { if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else { } else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} }
} else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) { } else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) {
...@@ -609,10 +609,10 @@ struct Algo { ...@@ -609,10 +609,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(bfloat16_t); const int max_num = SRC_MAX_SIZE / sizeof(bfloat16_t);
if (voc >= task_num * max_num) { if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else { } else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>( randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature); logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} }
} else { } else {
......
...@@ -267,7 +267,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -267,7 +267,7 @@ infiniStatus_t Descriptor::calculate(
dim.x = 4; // Using 4 clusters dim.x = 4; // Using 4 clusters
dim.y = 10; dim.y = 10;
dim.z = 1; dim.z = 1;
func_type = CNRT_FUNC_TYPE_UNION1; func_type = cnrtFuncTypeUnion1;
if (_opaque->use_2d_copy) { if (_opaque->use_2d_copy) {
// Use optimized 2D copy kernel // Use optimized 2D copy kernel
......
...@@ -82,7 +82,7 @@ __mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight, ...@@ -82,7 +82,7 @@ __mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight,
} }
} else { } else {
// Large vector processing with chunking // Large vector processing with chunking
__bang_write_zero(reduction_buffer, reduce_buffer_size); __bang_write_value(reduction_buffer, reduce_buffer_size, 0);
size_t processed_elements = 0; size_t processed_elements = 0;
while (processed_elements < vector_size) { while (processed_elements < vector_size) {
...@@ -223,9 +223,9 @@ void rmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrt ...@@ -223,9 +223,9 @@ void rmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrt
kernel_dim.x = core_per_cluster; kernel_dim.x = core_per_cluster;
kernel_dim.y = cluster_count; kernel_dim.y = cluster_count;
kernel_dim.z = 1; kernel_dim.z = 1;
kernel_type = CNRT_FUNC_TYPE_UNION1; // Can choose others, but must adapt kernel_type accordingly kernel_type = cnrtFuncTypeUnion1; // Can choose others, but must adapt kernel_type accordingly
int dimsize = shape[ndim - 1]; // Length of operation dimension int dimsize = shape[ndim - 1]; // Length of operation dimension
int dim_s; // dim_s is the next power of 2 greater than dimsize int dim_s; // dim_s is the next power of 2 greater than dimsize
float mi = log2(dimsize); float mi = log2(dimsize);
if (floor(mi) == mi) { if (floor(mi) == mi) {
dim_s = dimsize; dim_s = dimsize;
......
...@@ -52,7 +52,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -52,7 +52,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
k_dim.x = 4; k_dim.x = 4;
k_dim.y = 1; k_dim.y = 1;
k_dim.z = 1; k_dim.z = 1;
k_type = CNRT_FUNC_TYPE_UNION1; k_type = cnrtFuncTypeUnion1;
// Launch kernel with batch dimension // Launch kernel with batch dimension
ropeKernel<<<k_dim, k_type, queue>>>( ropeKernel<<<k_dim, k_type, queue>>>(
......
...@@ -50,7 +50,7 @@ __mlu_func__ float sum(const T *source, T *src, float *dst, int num_elements, in ...@@ -50,7 +50,7 @@ __mlu_func__ float sum(const T *source, T *src, float *dst, int num_elements, in
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed); size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) { if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset); __bang_write_value(src, max_batch + offset, 0);
} }
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
...@@ -81,7 +81,7 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme ...@@ -81,7 +81,7 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme
size_t remainder = curr_batch % batch_size; size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed // Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset); __bang_write_value(src, max_batch + offset, 0);
// Copy data to NRAM // Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
...@@ -120,7 +120,7 @@ __mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_eleme ...@@ -120,7 +120,7 @@ __mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_eleme
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed); size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) { if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset); __bang_write_value(src, max_batch + offset, 0);
} }
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
...@@ -165,7 +165,7 @@ __mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int nu ...@@ -165,7 +165,7 @@ __mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int nu
size_t remainder = curr_batch % batch_size; size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed // Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset); __bang_write_value(src, max_batch + offset, 0);
// Copy data to NRAM // Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
...@@ -235,7 +235,7 @@ __mlu_func__ float max(const T *source, T *src, float *dst, int num_elements, in ...@@ -235,7 +235,7 @@ __mlu_func__ float max(const T *source, T *src, float *dst, int num_elements, in
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed); size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) { if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset); __bang_write_value(src, max_batch + offset, 0);
} }
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
...@@ -264,7 +264,7 @@ __mlu_func__ float maxBatched(const T *source, T *src, float *dst, int num_eleme ...@@ -264,7 +264,7 @@ __mlu_func__ float maxBatched(const T *source, T *src, float *dst, int num_eleme
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed); size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) { if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset); __bang_write_value(src, max_batch + offset, 0);
} }
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment