Unverified commit 82e43d35 authored by PanZezhong1725, committed by GitHub

Merge pull request #364 from InfiniTensor/issue/240

Issue/240 - Cambricon Reduce
parents d4b03cf7 00c30a4a
@@ -9,7 +9,7 @@ struct ParsedArgs {
     int device_id = 0;    // CUDA device ID (if specified)
     int warmups = 0;      // Default to 0 if not given
     int iterations = 0;   // Default to 0 if not given
-    double atol = 0.001;  // Default absolute tolerance
+    double atol = 0.0015; // Default absolute tolerance
     double rtol = 0.001;  // Default relative tolerance
 };
......
#ifndef __CAUSAL_SOFTMAX_BANG_H__
#define __CAUSAL_SOFTMAX_BANG_H__
#include "../causal_softmax.h"
DESCRIPTOR(bang)
#endif // __CAUSAL_SOFTMAX_BANG_H__
#include "../../../devices/bang/common_bang.h"
#include "../../../reduce/bang/reduce_bang.h"
#include "causal_softmax_bang.h"
__nram__ char nram_buffer[NRAM_MAX_SIZE];
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;
template <typename T>
__mlu_func__ void processSoftmaxStep(T *output, const T *input, float scalar, int num_elements, int stride, bool is_exp_phase) {
// Calculate buffer sizes (split between float and T buffers)
constexpr bool is_half = std::is_same_v<T, half>;
constexpr bool is_bfloat16 = std::is_same_v<T, bfloat16_t>;
constexpr bool is_float = !is_half && !is_bfloat16;
const int chunk_size = SRC_MAX_SIZE / ((is_half || is_bfloat16) ? (2 * sizeof(float)) : sizeof(float));
float *float_buffer = (float *)nram_buffer;
T *temp_buffer = is_float ? nullptr : (T *)(nram_buffer + chunk_size * sizeof(float));
// Common stride configurations
const int src_stride = stride * sizeof(T);
const int dst_stride = stride * sizeof(T);
int processed = 0;
while (processed < num_elements) {
int curr_batch = std::min(chunk_size, num_elements - processed);
// Gather input elements using 2D memcpy
if constexpr (is_float) {
__memcpy(float_buffer, (is_exp_phase ? input : output) + processed * stride, sizeof(float),
GDRAM2NRAM, sizeof(float), src_stride, curr_batch - 1);
} else {
__memcpy(temp_buffer, (is_exp_phase ? input : output) + processed * stride, sizeof(T),
GDRAM2NRAM, sizeof(T), src_stride, curr_batch - 1);
// Convert to float
if constexpr (is_half) {
__bang_half2float(float_buffer, temp_buffer, curr_batch);
} else if constexpr (is_bfloat16) {
__bang_bfloat162float(float_buffer, temp_buffer, curr_batch);
}
}
// Common processing for all types
if (is_exp_phase) {
__bang_sub_scalar(float_buffer, float_buffer, scalar, curr_batch); // scalar is max_val
__bang_active_exphp(float_buffer, float_buffer, curr_batch);
} else {
__bang_mul_scalar(float_buffer, float_buffer, scalar, curr_batch); // scalar is 1.0f/sum_val
}
// Convert back and scatter output using 2D memcpy
if constexpr (is_float) {
__memcpy(output + processed * stride, float_buffer, sizeof(float),
NRAM2GDRAM, dst_stride, sizeof(float), curr_batch - 1);
} else {
// Convert back to original type
if constexpr (is_half) {
__bang_float2half(temp_buffer, float_buffer, curr_batch);
} else if constexpr (is_bfloat16) {
__bang_float2bfloat16(temp_buffer, float_buffer, curr_batch);
}
// Scatter output
__memcpy(output + processed * stride, temp_buffer, sizeof(T),
NRAM2GDRAM, dst_stride, sizeof(T), curr_batch - 1);
}
processed += curr_batch;
}
}
template <typename T>
__mlu_global__ void causalSoftmax(T *y, const T *x,
size_t batch_size, size_t seq_len, size_t total_seq_len,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_i, ptrdiff_t x_stride_j) {
using namespace op::common_bang::reduce_op;
// Get task information
size_t task_id = taskId;
size_t task_num = taskDimX * taskDimY;
// Calculate elements per task with better load balancing
size_t total_tasks = batch_size * seq_len;
size_t tasks_per_core = (total_tasks + task_num - 1) / task_num;
size_t start = task_id * tasks_per_core;
size_t end = std::min(start + tasks_per_core, total_tasks);
// Allocate NRAM buffers
const int max_batch = SRC_MAX_SIZE / sizeof(T);
T *src = (T *)nram_buffer;
float *dst = (float *)(nram_buffer + max_batch * sizeof(T));
for (size_t index = start; index < end; index++) {
size_t batch = index / seq_len;
size_t i = (index % seq_len);
ptrdiff_t y_offset = batch * y_stride_b + i * y_stride_i;
ptrdiff_t x_offset = batch * x_stride_b + i * x_stride_i;
T *y_ = y + y_offset;
const T *x_ = x + x_offset;
// Calculate the valid sequence length for this position
size_t valid_len = total_seq_len - seq_len + i + 1;
// Zero out future positions
for (size_t j = valid_len; j < total_seq_len; j++) {
y_[j * y_stride_j] = (T)0.0f;
}
// Calculate max value using optimized reduction
float max_val = maxBatched(x_, src, dst, valid_len, max_batch);
// Compute exp(x - max)
processSoftmaxStep(y_, x_, max_val, valid_len, x_stride_j, true);
// Calculate sum of exponentials
float sum_val = sumBatched(y_, src, dst, valid_len, max_batch);
// Normalize by sum
processSoftmaxStep(y_, y_, 1.0f / sum_val, valid_len, y_stride_j, false);
}
}
template <typename T>
void causalSoftmaxUnion(void *workspace, int core_per_cluster, int cluster_count,
cnrtQueue_t queue, void *y, const void *x, const op::causal_softmax::CausalSoftmaxInfo *info) {
cnrtDim3_t kernel_dim;
cnrtFunctionType_t kernel_type;
// Configure kernel dimensions
kernel_dim.x = core_per_cluster;
kernel_dim.y = cluster_count;
kernel_dim.z = 1;
kernel_type = CNRT_FUNC_TYPE_UNION1;
// Launch kernel
causalSoftmax<T><<<kernel_dim, kernel_type, queue>>>(
(T *)y, (const T *)x,
info->batch_size, info->seq_len, info->total_seq_len,
info->y_stride_b, info->y_stride_i, info->y_stride_j,
info->x_stride_b, info->x_stride_i, info->x_stride_j);
cnrtQueueSync(queue);
}
namespace op::causal_softmax::bang {
struct Descriptor::Opaque {
std::shared_ptr<device::bang::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);
auto result = CausalSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Descriptor::Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
info,
0,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream) const {
auto queue = reinterpret_cast<cnrtQueue_t>(stream);
int core_per_cluster = _opaque->internal->getCorePerCluster();
int cluster_count = _opaque->internal->getClusterCount();
// Dispatch based on data type
if (_info.dtype == INFINI_DTYPE_F16) {
causalSoftmaxUnion<half>(workspace, core_per_cluster, cluster_count, queue, y, x, &_info);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
causalSoftmaxUnion<bfloat16_t>(workspace, core_per_cluster, cluster_count, queue, y, x, &_info);
} else if (_info.dtype == INFINI_DTYPE_F32) {
causalSoftmaxUnion<float>(workspace, core_per_cluster, cluster_count, queue, y, x, &_info);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::causal_softmax::bang
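For reference, each row handled by the kernel above is a causally masked softmax: entries past valid_len = total_seq_len - seq_len + i + 1 are zeroed, and the surviving entries are rescaled to exp(x - max) / sum. A minimal host-side sketch of that math (plain C++, illustrative only, not part of this merge request; contiguous rows and float data assumed):

#include <algorithm>
#include <cmath>

// Illustrative CPU reference for one (batch, i) row; all strides assumed to be 1.
void causalSoftmaxRowRef(float *y, const float *x,
                         size_t i, size_t seq_len, size_t total_seq_len) {
    size_t valid_len = total_seq_len - seq_len + i + 1; // causal mask boundary
    float max_val = -INFINITY;
    for (size_t j = 0; j < valid_len; ++j) {
        max_val = std::max(max_val, x[j]);
    }
    float sum = 0.0f;
    for (size_t j = 0; j < valid_len; ++j) {
        y[j] = std::exp(x[j] - max_val); // subtract max for numerical stability
        sum += y[j];
    }
    for (size_t j = 0; j < valid_len; ++j) {
        y[j] /= sum; // normalize
    }
    std::fill(y + valid_len, y + total_seq_len, 0.0f); // masked (future) positions
}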
@@ -14,6 +14,9 @@
 #ifdef ENABLE_ASCEND_API
 #include "ascend/causal_softmax_ascend.h"
 #endif
+#ifdef ENABLE_CAMBRICON_API
+#include "bang/causal_softmax_bang.h"
+#endif

 __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
     infiniopHandle_t handle,
@@ -39,22 +42,14 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
 #ifdef ENABLE_ILUVATAR_API
         CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_CAMBRICON_API
+        CREATE(INFINI_DEVICE_CAMBRICON, bang)
+#endif
 #ifdef ENABLE_METAX_API
         CREATE(INFINI_DEVICE_METAX, metax)
 #endif
 #ifdef ENABLE_ASCEND_API
         CREATE(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangCreateCausalSoftmaxDescriptor((BangHandle_t)handle, (CausalSoftmaxBangDescriptor_t *)desc_ptr, y_desc);
-        // return cnnlCreateCausalSoftmaxDescriptor((BangHandle_t) handle, (CausalSoftmaxCnnlDescriptor_t *) desc_ptr, y_desc);
-    }
-#endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu: {
-        return musaCreateCausalSoftmaxDescriptor((MusaHandle_t)handle, (CausalSoftmaxMusaDescriptor_t *)desc_ptr, y_desc);
-    }
-#endif
     }
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -83,17 +78,8 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
 #ifdef ENABLE_ASCEND_API
         GET(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangGetCausalSoftmaxWorkspaceSize((CausalSoftmaxBangDescriptor_t)desc, size);
-        // return cnnlGetCausalSoftmaxWorkspaceSize((CausalSoftmaxCnnlDescriptor_t) desc, size);
-    }
-#endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu: {
-        return musaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMusaDescriptor_t)desc, size);
-    }
-#endif
+#ifdef ENABLE_CAMBRICON_API
+        GET(INFINI_DEVICE_CAMBRICON, bang)
+#endif
     }
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -121,22 +107,14 @@ __C infiniStatus_t infiniopCausalSoftmax(
 #ifdef ENABLE_ILUVATAR_API
         CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_CAMBRICON_API
+        CALCULATE(INFINI_DEVICE_CAMBRICON, bang)
+#endif
 #ifdef ENABLE_METAX_API
         CALCULATE(INFINI_DEVICE_METAX, metax)
 #endif
 #ifdef ENABLE_ASCEND_API
         CALCULATE(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangCausalSoftmax((CausalSoftmaxBangDescriptor_t)desc, workspace, workspace_size, data, stream);
-        // return cnnlCausalSoftmax((CausalSoftmaxCnnlDescriptor_t) desc, workspace, workspace_size, data, stream);
-    }
-#endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu: {
-        return musaCausalSoftmax((CausalSoftmaxMusaDescriptor_t)desc, workspace, workspace_size, data, stream);
-    }
-#endif
     }
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -159,21 +137,14 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
 #ifdef ENABLE_ILUVATAR_API
         DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_CAMBRICON_API
+        DESTROY(INFINI_DEVICE_CAMBRICON, bang)
+#endif
 #ifdef ENABLE_METAX_API
         DESTROY(INFINI_DEVICE_METAX, metax)
 #endif
 #ifdef ENABLE_ASCEND_API
         DESTROY(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangDestroyCausalSoftmaxDescriptor((CausalSoftmaxBangDescriptor_t)desc);
-        // return cnnlDestroyCausalSoftmaxDescriptor((CausalSoftmaxCnnlDescriptor_t) desc);
-    }
-#endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu:
-        return musaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMusaDescriptor_t)desc);
-#endif
     }
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __RMS_NORM_BANG_H__
#define __RMS_NORM_BANG_H__
#include "../rms_norm.h"
DESCRIPTOR(bang)
#endif // __RMS_NORM_BANG_H__
#include "../../../devices/bang/common_bang.h"
#include "../../../reduce/bang/reduce_bang.h"
#include "rms_norm_bang.h"
__nram__ char nram_buffer[NRAM_MAX_SIZE];
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;
template <typename T, typename Tw>
__mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight,
size_t *shape, ptrdiff_t *output_strides, ptrdiff_t *input_strides,
float epsilon, int num_dims, int norm_dim_size) {
// Calculate problem dimensions
int batch_volume = 1;
for (int dim_idx = 0; dim_idx < num_dims - 1; ++dim_idx) {
batch_volume *= shape[dim_idx];
}
int vector_size = shape[num_dims - 1];
// Determine maximum batch size for NRAM operations
int max_batch_size = (vector_size >= SRC_MAX_SIZE / sizeof(Tw) ? SRC_MAX_SIZE / sizeof(Tw) : norm_dim_size);
constexpr int reduce_buffer_size = 128 / sizeof(float);
// Task distribution across cores
int remaining_tasks = batch_volume % taskDim;
int base_tasks_per_core = batch_volume / taskDim;
int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
int task_start_idx = (taskId < remaining_tasks ? taskId * base_tasks_per_core + taskId : taskId * base_tasks_per_core + remaining_tasks);
// NRAM buffer allocation
int half_type_offset = (sizeof(T) == 2 ? max_batch_size : 0);
char *input_buffer = nram_buffer + reduce_buffer_size * sizeof(float);
char *weight_buffer = input_buffer + (max_batch_size + half_type_offset) * sizeof(T);
float *reduction_result = (float *)nram_buffer;
T *input_cache = (T *)input_buffer;
Tw *weight_cache = (Tw *)weight_buffer;
// Process vectors assigned to current core
int processed_tasks = 0;
while (processed_tasks < actual_tasks) {
int input_offset = 0;
int output_offset = 0;
int current_index = task_start_idx + processed_tasks;
// Calculate memory offsets for current task
for (int dim = num_dims - 2; dim >= 0; --dim) {
input_offset += (current_index % shape[dim]) * input_strides[dim];
output_offset += (current_index % shape[dim]) * output_strides[dim];
current_index = current_index / shape[dim];
}
// Compute sum of squares
__bang_write_zero(reduction_result, reduce_buffer_size);
float sum_squared = op::common_bang::reduce_op::sumSquaredBatched<T>(
input + input_offset, input_cache, reduction_result, vector_size, max_batch_size);
// Compute normalization factor
float rms_value = sum_squared / vector_size;
rms_value += epsilon;
rms_value = sqrtf(rms_value);
float inv_rms = 1.0f / rms_value;
// Process vector in chunks
size_t processed_elements = 0;
while (processed_elements < vector_size) {
size_t current_batch = std::min((size_t)max_batch_size, vector_size - processed_elements);
// Load data
__memcpy(input_cache, input + input_offset + processed_elements, current_batch * sizeof(T), GDRAM2NRAM);
__memcpy(weight_cache, weight + processed_elements, current_batch * sizeof(Tw), GDRAM2NRAM);
// Normalization and scaling
if constexpr (std::is_same<T, bfloat16_t>::value && std::is_same<Tw, float>::value) {
// Special handling for BF16 input with F32 weights
__bang_bfloat162float((float *)input_cache, input_cache, current_batch);
__bang_mul((float *)input_cache, (float *)input_cache, weight_cache, current_batch);
__bang_mul_scalar((float *)input_cache, (float *)input_cache, inv_rms, current_batch);
__bang_float2bfloat16(input_cache, (float *)input_cache, current_batch);
} else {
if constexpr (std::is_same<T, half>::value && std::is_same<Tw, float>::value) {
__bang_float2half_dn((T *)weight_cache, weight_cache, current_batch);
}
__bang_mul(input_cache, input_cache, (T *)weight_cache, current_batch);
__bang_mul_scalar(input_cache, input_cache, inv_rms, current_batch);
}
// Store results
__memcpy(output + output_offset + processed_elements, input_cache, current_batch * sizeof(T), NRAM2GDRAM);
processed_elements += current_batch;
}
processed_tasks++;
}
}
template <typename T, typename Tw>
void rmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrtQueue_t queue, void *y, const void *x, const void *w, const size_t *shape, const ptrdiff_t *y_strides, const ptrdiff_t *x_strides, float eps, int ndim) {
cnrtDim3_t kernel_dim;
cnrtFunctionType_t kernel_type;
// Configure kernel dimensions
kernel_dim.x = core_per_cluster;
kernel_dim.y = cluster_count;
kernel_dim.z = 1;
kernel_type = CNRT_FUNC_TYPE_UNION1; // Can choose others, but must adapt kernel_type accordingly
int dimsize = shape[ndim - 1]; // Length of operation dimension
int dim_s; // dim_s is the smallest power of two that is >= dimsize
float mi = log2(dimsize);
if (floor(mi) == mi) {
dim_s = dimsize;
} else {
dim_s = pow(2, floor(mi) + 1);
}
constexpr int reduce_num = 128 / sizeof(float); // Cambricon __bang_reduce_sum can only reduce 128 bytes at a time
if (dim_s < reduce_num) {
dim_s = reduce_num; // Force dim_s >= reduce_num
}
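// Examples: dimsize = 20 -> dim_s = 32; dimsize = 64 -> dim_s = 64;
// dimsize = 5 -> dim_s = 8, which the clamp above raises to reduce_num (32 floats).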
// Prepare device pointers
auto y_ = reinterpret_cast<T *>(y);
auto x_ = reinterpret_cast<const T *>(x);
auto w_ = reinterpret_cast<const Tw *>(w);
char *tmp_device = reinterpret_cast<char *>(workspace);
char *tmp_stride = tmp_device + ndim * sizeof(size_t);
size_t *mlu_shape = (size_t *)tmp_device;
ptrdiff_t *mlu_x_strides = (ptrdiff_t *)tmp_stride;
ptrdiff_t *mlu_y_strides = mlu_x_strides + ndim;
// Copy shape and stride information to device
CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast<size_t *>(shape), ndim * sizeof(size_t), queue, cnrtMemcpyHostToDev)); // cnrtMemcpyAsync takes a non-const source pointer, hence the const_cast
CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast<ptrdiff_t *>(x_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast<ptrdiff_t *>(y_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
// Launch kernel
rmsnorm<T, Tw><<<kernel_dim, kernel_type, queue>>>(y_, x_, w_, mlu_shape, mlu_y_strides, mlu_x_strides, eps, ndim, dim_s);
cnrtQueueSync(queue);
}
namespace op::rms_norm::bang {
struct Descriptor::Opaque {
std::shared_ptr<device::bang::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = info.ndim() * (sizeof(size_t) + 2 * sizeof(ptrdiff_t));
*desc_ptr = new Descriptor(
new Descriptor::Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
info,
workspace_size,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
auto queue = reinterpret_cast<cnrtQueue_t>(stream);
int core_per_cluster = _opaque->internal->getCorePerCluster();
int cluster_count = _opaque->internal->getClusterCount();
// Dispatch based on data types
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
rmsnormUnion<half, half>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<half, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_F32) {
if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<float, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_BF16) {
if (_info.wtype == INFINI_DTYPE_BF16) {
rmsnormUnion<bfloat16_t, bfloat16_t>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<bfloat16_t, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::bang
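For reference, the kernel above computes, per normalized row, y = x * w / sqrt(mean(x^2) + epsilon). A minimal host-side sketch of that math (plain C++, illustrative only, not part of this merge request; contiguous float data assumed):

#include <cmath>
#include <cstddef>

// Illustrative CPU reference for RMS normalization of one row of length n.
void rmsNormRowRef(float *y, const float *x, const float *w,
                   size_t n, float epsilon) {
    float sum_squared = 0.0f;
    for (size_t j = 0; j < n; ++j) {
        sum_squared += x[j] * x[j];
    }
    float inv_rms = 1.0f / std::sqrt(sum_squared / n + epsilon); // 1 / RMS
    for (size_t j = 0; j < n; ++j) {
        y[j] = x[j] * w[j] * inv_rms; // scale and apply learned weight
    }
}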
@@ -11,6 +11,9 @@
 #ifdef ENABLE_ASCEND_API
 #include "ascend/rms_norm_aclnn.h"
 #endif
+#ifdef ENABLE_CAMBRICON_API
+#include "bang/rms_norm_bang.h"
+#endif
 #ifdef ENABLE_METAX_API
 #include "metax/rms_norm_metax.cuh"
 #endif
@@ -52,10 +55,8 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
 #ifdef ENABLE_KUNLUN_API
         CREATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangCreateRMSNormDescriptor((BangHandle_t)handle, (RMSNormBangDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
-    }
+#ifdef ENABLE_CAMBRICON_API
+        CREATE(INFINI_DEVICE_CAMBRICON, bang);
 #endif
 #ifdef ENABLE_ASCEND_API
         CREATE(INFINI_DEVICE_ASCEND, ascend);
@@ -93,10 +94,8 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
 #ifdef ENABLE_KUNLUN_API
         GET(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangGetRMSNormWorkspaceSize((RMSNormBangDescriptor_t)desc, size);
-    }
+#ifdef ENABLE_CAMBRICON_API
+        GET(INFINI_DEVICE_CAMBRICON, bang);
 #endif
 #ifdef ENABLE_ASCEND_API
         GET(INFINI_DEVICE_ASCEND, ascend);
@@ -135,10 +134,8 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
 #ifdef ENABLE_KUNLUN_API
         CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangRMSNorm((RMSNormBangDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
-    }
+#ifdef ENABLE_CAMBRICON_API
+        CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
 #endif
 #ifdef ENABLE_ASCEND_API
         CALCULATE(INFINI_DEVICE_ASCEND, ascend);
@@ -176,10 +173,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
 #ifdef ENABLE_KUNLUN_API
         DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
-#ifdef ENABLE_CAMBRICON_MLU
-    case DevCambriconMlu: {
-        return bangDestroyRMSNormDescriptor((RMSNormBangDescriptor_t)desc);
-    }
+#ifdef ENABLE_CAMBRICON_API
+        DESTROY(INFINI_DEVICE_CAMBRICON, bang);
 #endif
 #ifdef ENABLE_ASCEND_API
         DESTROY(INFINI_DEVICE_ASCEND, ascend);
......
#ifndef __INFINIOP_REDUCE_BANG_H__
#define __INFINIOP_REDUCE_BANG_H__
#include "../../devices/bang/common_bang.h"
namespace op::common_bang::reduce_op {
constexpr int batch_size = 128 / sizeof(float);
__mlu_func__ void sumInternal(float *dst, float *src, int max_batch) {
const int width = max_batch / batch_size;
// Use vectorized reduction
if (width >= 4) {
__bang_sumpool(
dst, src,
batch_size, 1, width,
1, width, 1, 1);
__bang_reduce_sum(dst, dst, batch_size);
} else {
// Fallback for small batches
float sum = 0.0f;
for (int i = 0; i < max_batch; ++i) {
sum += src[i];
}
dst[0] = sum;
}
}
template <typename T>
__mlu_func__ void sumTyped(float *result, T *data, size_t len) {
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)data, data + len, len);
sumInternal(result, (float *)data, len);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)data, data + len, len);
sumInternal(result, (float *)data, len);
} else {
sumInternal(result, data, len);
}
}
template <typename T>
__mlu_func__ float sum(const T *source, T *src, float *dst, int num_elements, int max_batch) {
float res = 0.0f;
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
sumTyped(dst, src, max_batch);
res += dst[0];
processed += curr_batch;
}
return res;
}
template <typename T>
__mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
constexpr int min_vector_size = 32;
// For small vectors, use safer element-wise computation
if (num_elements < min_vector_size) {
return sum(source, src, dst, num_elements, max_batch);
}
float res = 0.0f;
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
size_t aligned_batch = (curr_batch / batch_size) * batch_size;
size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset);
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)(src + offset), src + offset, curr_batch);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
}
// Process aligned portion
if (aligned_batch > 0) {
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0];
}
// Process unaligned tail
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
res += ((float *)(src + offset))[i];
}
}
processed += curr_batch;
}
return res;
}
template <typename T>
__mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_elements, int max_batch) {
float res = 0.0f;
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
float sum = 0.0f;
for (size_t i = 0; i < curr_batch; ++i) {
float val = 0.0f;
if constexpr (std::is_same_v<T, half>) {
val = __half2float(src[offset + i]);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
val = __bfloat162float(src[offset + i]);
} else {
val = src[offset + i];
}
sum += val * val;
}
res += sum;
processed += curr_batch;
}
return res;
}
template <typename T>
__mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
constexpr int min_vector_size = 32;
// For small vectors, use safer element-wise computation
if (num_elements < min_vector_size) {
return sumSquared(source, src, dst, num_elements, max_batch);
}
float res = 0.0f;
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
size_t aligned_batch = (curr_batch / batch_size) * batch_size;
size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset);
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)(src + offset), src + offset, curr_batch);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
}
// Process aligned portion
if (aligned_batch > 0) {
__bang_mul((float *)(src + offset), (float *)(src + offset),
(float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0];
}
// Process unaligned tail
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = ((float *)(src + offset))[i];
res += val * val;
}
}
processed += curr_batch;
}
return res;
}
__mlu_func__ void maxInternal(float *dst, float *src, int max_batch) {
__bang_maxpool(
dst, src,
batch_size, // channel size
1, // height
max_batch / batch_size, // width
1, // kernel_height
max_batch / batch_size, // kernel_width
1, // stride_height
1 // stride_width
);
__bang_argmax(dst, dst, batch_size);
}
template <typename T>
__mlu_func__ void maxTyped(float *result, T *data, size_t len) {
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)data, data + len, len);
maxInternal(result, (float *)data, len);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)data, data + len, len);
maxInternal(result, (float *)data, len);
} else {
maxInternal(result, data, len);
}
}
template <typename T>
__mlu_func__ float max(const T *source, T *src, float *dst, int num_elements, int max_batch) {
float max_val = -INFINITY;
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
maxTyped(dst, src, max_batch);
max_val = std::max(max_val, dst[0]);
processed += curr_batch;
}
return max_val;
}
template <typename T>
__mlu_func__ float maxBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
constexpr int min_vector_size = 32;
// For small vectors, use safer element-wise computation
if (num_elements < min_vector_size) {
return max(source, src, dst, num_elements, max_batch);
}
float max_val = -INFINITY;
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
maxTyped(dst, src, max_batch);
max_val = std::max(max_val, dst[0]);
processed += curr_batch;
}
return max_val;
}
} // namespace op::common_bang::reduce_op
#endif // __INFINIOP_REDUCE_BANG_H__
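The batched reducers above share a two-stage pattern: __bang_sumpool / __bang_maxpool first collapse the NRAM buffer into batch_size (128 bytes / sizeof(float) = 32) partial lanes, and a final 128-byte __bang_reduce_sum / __bang_argmax collapses those lanes into one value. A plain-C++ analogue of that structure (illustrative only; it does not reproduce the intrinsics or their exact on-chip data layout, and the lane-interleaved indexing is an assumption):

#include <cstddef>

// Illustrative two-stage sum over `len` floats, mirroring the sumpool +
// reduce_sum structure of sumInternal (len assumed to be a multiple of 32).
float blockedSumRef(const float *src, size_t len) {
    constexpr size_t lanes = 128 / sizeof(float); // 32 partial sums = 128 bytes
    const size_t width = len / lanes;
    float partial[lanes] = {};
    // Stage 1: "pool" the buffer into `lanes` partial sums.
    for (size_t lane = 0; lane < lanes; ++lane) {
        for (size_t k = 0; k < width; ++k) {
            partial[lane] += src[k * lanes + lane];
        }
    }
    // Stage 2: reduce the 32 partial sums into a single value.
    float total = 0.0f;
    for (size_t lane = 0; lane < lanes; ++lane) {
        total += partial[lane];
    }
    return total;
}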
@@ -41,7 +41,7 @@ _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
 _TOLERANCE_MAP = {
     InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
     InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
-    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
+    InfiniDtype.F32: {"atol": 3e-5, "rtol": 1e-5},
 }
......
@@ -46,7 +46,7 @@ _TEST_CASES = [
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
     InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
-    InfiniDtype.BF16: {"atol": 8e-3, "rtol": 8e-3},
+    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
 }

 DEBUG = False
......