Commit 7e3ade06 authored by wooway777's avatar wooway777
Browse files

issue/240 - added bf16 support to cambricon rms norm and adjusted tolerance

parent ef577d9d
#include "../../../devices/bang/common_bang.h"
#include "../../../reduce/bang/reduce_bang.h"
#include "rms_norm_bang.h"
// Per-core on-chip NRAM scratch space shared by the kernels in this file.
__nram__ char nram_buffer[NRAM_MAX_SIZE];
// Upper bound (in bytes) for one staging area: a quarter of NRAM, leaving
// room for the reduction scratch and the weight buffer carved out below.
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;
template <typename T, typename Tw>
// RMS-normalizes each row (the last dimension) of `input` into `output`:
//   out[j] = in[j] * weight[j] / sqrt(mean(in^2) + epsilon)
// T  - activation type (float, half or bfloat16_t)
// Tw - weight type (may be float while T is half/bf16)
// `shape`, `output_strides`, `input_strides` are device arrays of length
// `num_dims`; `norm_dim_size` is a host-computed power-of-two chunk hint.
__mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight,
size_t *shape, ptrdiff_t *output_strides, ptrdiff_t *input_strides,
float epsilon, int num_dims, int norm_dim_size) {
// Calculate problem dimensions:
// batch_volume = number of rows to normalize (product of all but last dim)
int batch_volume = 1;
for (int dim_idx = 0; dim_idx < num_dims - 1; ++dim_idx) {
batch_volume *= shape[dim_idx];
}
int vector_size = shape[num_dims - 1];
// Determine maximum batch (chunk) size for NRAM operations.
// NOTE(review): the capacity bound uses sizeof(Tw), yet the staging buffer
// below holds T elements - confirm this is intended when sizeof(T) != sizeof(Tw).
int max_batch_size = (vector_size >= SRC_MAX_SIZE / sizeof(Tw) ? SRC_MAX_SIZE / sizeof(Tw) : norm_dim_size);
// 128 bytes = one __bang_reduce_sum unit (32 floats)
constexpr int reduce_buffer_size = 128 / sizeof(float);
// Task distribution across cores: the first `remaining_tasks` cores take
// one extra row each.
int remaining_tasks = batch_volume % taskDim;
int base_tasks_per_core = batch_volume / taskDim;
int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
int task_start_idx = (taskId < remaining_tasks ? taskId * base_tasks_per_core + taskId : taskId * base_tasks_per_core + remaining_tasks);
// NRAM buffer layout:
//   [reduction scratch | input staging | weight staging]
// 2-byte activation types reserve 2x input space so an in-place widening
// to float fits inside the staging area.
int half_type_offset = (sizeof(T) == 2 ? max_batch_size : 0);
char *input_buffer = nram_buffer + reduce_buffer_size * sizeof(float);
char *weight_buffer = input_buffer + (max_batch_size + half_type_offset) * sizeof(T);
float *reduction_result = (float *)nram_buffer;
T *input_cache = (T *)input_buffer;
Tw *weight_cache = (Tw *)weight_buffer;
// Process the rows assigned to the current core
int processed_tasks = 0;
while (processed_tasks < actual_tasks) {
int input_offset = 0;
int output_offset = 0;
int current_index = task_start_idx + processed_tasks;
// Decompose the flat row index into per-dimension coordinates and
// accumulate element offsets over the leading dimensions.
// NOTE(review): the last dimension is addressed as `+ processed_elements`
// below, i.e. it is assumed contiguous (stride 1) - confirm upstream.
for (int dim = num_dims - 2; dim >= 0; --dim) {
input_offset += (current_index % shape[dim]) * input_strides[dim];
output_offset += (current_index % shape[dim]) * output_strides[dim];
current_index = current_index / shape[dim];
}
// Compute the row's sum of squares (streams through NRAM)
__bang_write_zero(reduction_result, reduce_buffer_size);
float sum_squared = op::common_bang::reduce_op::sumSquaredBatched<T>(
input + input_offset, input_cache, reduction_result, vector_size, max_batch_size);
// Normalization factor: 1 / sqrt(mean(x^2) + eps)
float rms_value = sum_squared / vector_size;
rms_value += epsilon;
rms_value = sqrtf(rms_value);
float inv_rms = 1.0f / rms_value;
// Normalize and scale the row in NRAM-sized chunks
size_t processed_elements = 0;
while (processed_elements < vector_size) {
size_t current_batch = std::min((size_t)max_batch_size, vector_size - processed_elements);
// Load activation and weight chunks into NRAM
__memcpy(input_cache, input + input_offset + processed_elements, current_batch * sizeof(T), GDRAM2NRAM);
__memcpy(weight_cache, weight + processed_elements, current_batch * sizeof(Tw), GDRAM2NRAM);
// Normalization and scaling
if constexpr (std::is_same<T, bfloat16_t>::value && std::is_same<Tw, float>::value) {
// BF16 input with F32 weights: widen in place (extra space reserved
// above), compute in float, narrow back to bf16 for the store.
__bang_bfloat162float((float *)input_cache, input_cache, current_batch);
__bang_mul((float *)input_cache, (float *)input_cache, weight_cache, current_batch);
__bang_mul_scalar((float *)input_cache, (float *)input_cache, inv_rms, current_batch);
__bang_float2bfloat16(input_cache, (float *)input_cache, current_batch);
} else {
if constexpr (std::is_same<T, half>::value && std::is_same<Tw, float>::value) {
// F16 activations with F32 weights: demote the weights to half first
__bang_float2half_dn((T *)weight_cache, weight_cache, current_batch);
}
__bang_mul(input_cache, input_cache, (T *)weight_cache, current_batch);
__bang_mul_scalar(input_cache, input_cache, inv_rms, current_batch);
}
// Store the normalized chunk back to GDRAM
__memcpy(output + output_offset + processed_elements, input_cache, current_batch * sizeof(T), NRAM2GDRAM);
processed_elements += current_batch;
}
processed_tasks++;
}
}
template <typename T, typename Tw>
// Host-side launcher for the rmsnorm kernel.
// Copies shape/stride metadata into `workspace` (device memory, laid out as
// [shape | x strides | y strides]), configures a UNION1 task topology and
// launches the kernel, then synchronizes the queue.
// T is the activation type, Tw the weight type; `eps` is added to the mean
// of squares before the square root.
void rmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrtQueue_t queue, void *y, const void *x, const void *w, const size_t *shape, const ptrdiff_t *y_strides, const ptrdiff_t *x_strides, float eps, int ndim) {
    cnrtDim3_t kernel_dim;
    cnrtFunctionType_t kernel_type;
    // Configure kernel dimensions
    kernel_dim.x = core_per_cluster;
    kernel_dim.y = cluster_count;
    kernel_dim.z = 1;
    kernel_type = CNRT_FUNC_TYPE_UNION1; // Can choose others, but must adapt kernel_type accordingly
    int dimsize = shape[ndim - 1]; // Length of the normalized (last) dimension
    // Cambricon __bang_reduce_sum can only reduce 128 bytes at a time, so the
    // chunk-size hint must be at least that many floats.
    constexpr int reduce_num = 128 / sizeof(float);
    // dim_s = smallest power of two that is >= dimsize and >= reduce_num.
    // Computed with exact integer doubling: the previous float
    // log2/floor/pow round-trip narrowed double->float and could mis-round
    // for large sizes.
    int dim_s = reduce_num;
    while (dim_s < dimsize) {
        dim_s *= 2;
    }
    // Prepare device pointers
    auto y_ = reinterpret_cast<T *>(y);
    auto x_ = reinterpret_cast<const T *>(x);
    auto w_ = reinterpret_cast<const Tw *>(w);
    char *tmp_device = reinterpret_cast<char *>(workspace);
    char *tmp_stride = tmp_device + ndim * sizeof(size_t);
    size_t *mlu_shape = (size_t *)tmp_device;
    ptrdiff_t *mlu_x_strides = (ptrdiff_t *)tmp_stride;
    ptrdiff_t *mlu_y_strides = mlu_x_strides + ndim;
    // Copy shape and stride information to the device (const_cast because the
    // CNRT memcpy API does not accept const source pointers).
    CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast<size_t *>(shape), ndim * sizeof(size_t), queue, cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast<ptrdiff_t *>(x_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
    CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast<ptrdiff_t *>(y_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
    // Launch kernel and wait for completion (metadata copies are ordered on
    // the same queue, so they land before the kernel runs).
    rmsnorm<T, Tw><<<kernel_dim, kernel_type, queue>>>(y_, x_, w_, mlu_shape, mlu_y_strides, mlu_x_strides, eps, ndim, dim_s);
    cnrtQueueSync(queue);
}
namespace op::rms_norm::bang {
// Device-specific state hidden behind the public Descriptor.
struct Descriptor::Opaque {
std::shared_ptr<device::bang::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
// Validates the tensor descriptors and builds a Descriptor.
// Workspace requirement (device memory), sized here and consumed by
// rmsnormUnion:
//   ndim * sizeof(size_t)    - shape
//   2 * ndim * sizeof(ptrdiff_t) - x and y strides
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = info.ndim() * (sizeof(size_t) + 2 * sizeof(ptrdiff_t));
*desc_ptr = new Descriptor(
new Descriptor::Opaque{static_cast<device::bang::Handle *>(handle)->internal()},
info,
workspace_size,
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// Runs RMS norm over the last dimension: y = x * w / sqrt(mean(x^2) + eps).
// Dispatches on (activation dtype, weight dtype); supported combinations:
// F16/F16, F16/F32, F32/F32, BF16/BF16, BF16/F32.
// NOTE(review): `workspace_size` is not validated against the size computed
// in create() - confirm callers always pass a large-enough workspace.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
auto queue = reinterpret_cast<cnrtQueue_t>(stream);
int core_per_cluster = _opaque->internal->getCorePerCluster();
int cluster_count = _opaque->internal->getClusterCount();
// Dispatch based on data types
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
rmsnormUnion<half, half>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<half, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_F32) {
if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<float, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_BF16) {
if (_info.wtype == INFINI_DTYPE_BF16) {
rmsnormUnion<bfloat16_t, bfloat16_t>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<bfloat16_t, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::bang
#ifndef __INFINIOP_REDUCE_BANG_H__
#define __INFINIOP_REDUCE_BANG_H__
#include "../../devices/bang/common_bang.h"
namespace op::common_bang::reduce_op {
// One hardware reduction unit: __bang_reduce_sum consumes 128 bytes
// (= 32 floats) per call, so the batched paths below work in multiples of this.
constexpr int batch_size = 128 / sizeof(float);
// Reduce `max_batch` floats held in NRAM `src` into dst[0].
// Wide inputs use the pooled vector reduction; narrow ones fall back to a
// plain scalar accumulation.
__mlu_func__ void sumInternal(float *dst, float *src, int max_batch) {
    const int num_lanes = max_batch / batch_size;
    if (num_lanes < 4) {
        // Narrow input: accumulate element by element.
        float acc = 0.0f;
        int idx = 0;
        while (idx < max_batch) {
            acc += src[idx];
            ++idx;
        }
        dst[0] = acc;
        return;
    }
    // Wide input: pool the lanes column-wise, then reduce the 128-byte row.
    __bang_sumpool(
        dst, src,
        batch_size, 1, num_lanes,
        1, num_lanes, 1, 1);
    __bang_reduce_sum(dst, dst, batch_size);
}
// Sum `len` elements of type T into result[0] (as float).
// 2-byte types are expected staged at data + len; they are widened to float
// into the front of the buffer before the reduction.
template <typename T>
__mlu_func__ void sumTyped(float *result, T *data, size_t len) {
    if constexpr (std::is_same_v<T, bfloat16_t>) {
        __bang_bfloat162float((float *)data, data + len, len);
        sumInternal(result, (float *)data, len);
    } else if constexpr (std::is_same_v<T, half>) {
        __bang_half2float((float *)data, data + len, len);
        sumInternal(result, (float *)data, len);
    } else {
        sumInternal(result, data, len);
    }
}
// Total (as float) of `num_elements` values streamed from GDRAM `source`,
// staged through the NRAM buffer `src` in chunks of `max_batch`.
// `dst` is a small NRAM scratch area that receives each chunk's reduction.
template <typename T>
__mlu_func__ float sum(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    // 2-byte element types park raw data in the back half of the buffer so
    // the float widening can land in the front half.
    const int stage_off = (sizeof(T) == 2 ? max_batch : 0);
    float total = 0.0f;
    for (size_t done = 0; done < num_elements;) {
        const size_t chunk = std::min<size_t>(max_batch, num_elements - done);
        if (chunk < max_batch) {
            // Partial tail: zero the whole buffer so the padding adds 0.
            __bang_write_zero(src, max_batch + stage_off);
        }
        __memcpy(src + stage_off, source + done, chunk * sizeof(T), GDRAM2NRAM);
        sumTyped(dst, src, max_batch);
        total += dst[0];
        done += chunk;
    }
    return total;
}
// Sum (as float) of `num_elements` values streamed from GDRAM `source`
// through the NRAM buffer `src`: each chunk's 128-byte-aligned portion is
// reduced with vector intrinsics and the unaligned tail is added scalar-wise.
// `dst` is NRAM scratch for intermediate reductions.
template <typename T>
__mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
constexpr int min_vector_size = 32;
// For small vectors, use safer element-wise computation
if (num_elements < min_vector_size) {
return sum(source, src, dst, num_elements, max_batch);
}
float res = 0.0f;
// 2-byte types stage raw data at src + max_batch so the float widening
// below can use the front of the buffer.
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
// Split the chunk into a 128-byte-aligned part and a scalar tail.
size_t aligned_batch = (curr_batch / batch_size) * batch_size;
size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset);
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
if constexpr (std::is_same_v<T, float>) {
// Process aligned portion with the vector reduction
if (aligned_batch > 0) {
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0];
}
// Process unaligned tail element-wise
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
res += ((float *)(src + offset))[i];
}
}
} else {
// half/bfloat16 path: widen to float first.
// NOTE(review): source and destination overlap here (in-place 2-byte ->
// 4-byte widening at the same address); other helpers in this file read
// from `data + len` into `data` instead - confirm the conversion
// intrinsics tolerate this overlap.
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)(src + offset), src + offset, curr_batch);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
}
// Process aligned portion with the vector reduction
if (aligned_batch > 0) {
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0];
}
// Process unaligned tail element-wise
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
res += ((float *)(src + offset))[i];
}
}
}
processed += curr_batch;
}
return res;
}
// Sum of squares (as float) of `num_elements` values streamed from GDRAM,
// computed element by element. To avoid float overflow on large magnitudes,
// each chunk is scaled down before squaring and the scale is divided back
// out afterwards. `dst` is unused here; the parameter is kept for signature
// parity with sumSquaredBatched.
template <typename T>
__mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_elements, int max_batch) {
float res = 0.0f;
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
// Partial tail: zero the buffer so stale data is not read below.
__bang_write_zero(src, max_batch + offset);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
// Pass 1: find the max absolute value of the chunk
float max_val = 0.0f;
for (size_t i = 0; i < curr_batch; ++i) {
float val = 0.0f;
if constexpr (std::is_same_v<T, half>) {
val = fabs(__half2float(src[offset + i]));
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
val = fabs(__bfloat162float(src[offset + i]));
} else {
val = fabs(src[offset + i]);
}
max_val = std::max(val, max_val);
}
float scale = (max_val > 1e3f) ? 1e3f / max_val : 1.0f; // Prevent overflow
float sum = 0.0f;
// Pass 2: accumulate (x * scale)^2; the scale is undone below
for (size_t i = 0; i < curr_batch; ++i) {
float val = 0.0f;
if constexpr (std::is_same_v<T, half>) {
val = __half2float(src[offset + i]) * scale;
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
val = __bfloat162float(src[offset + i]) * scale;
} else {
val = src[offset + i] * scale;
}
sum += val * val;
}
// Undo the protective scaling: sum(x^2) = sum((x*s)^2) / s^2
res += sum / (scale * scale);
processed += curr_batch;
}
return res;
}
// Sum of squares (as float) of `num_elements` values streamed from GDRAM,
// vectorized per chunk: aligned portions use vector multiply + reduction,
// the unaligned tail is handled scalar-wise. Half/bf16 chunks are scaled
// down before squaring to avoid float overflow. `dst` is NRAM scratch.
template <typename T>
__mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
constexpr int min_vector_size = 32;
// For small vectors, use safer element-wise computation
if (num_elements < min_vector_size) {
return sumSquared(source, src, dst, num_elements, max_batch);
}
float res = 0.0f;
// 2-byte types stage raw data at src + max_batch so the float widening
// below can use the front of the buffer.
int offset = (sizeof(T) == 2 ? max_batch : 0);
size_t processed = 0;
while (processed < num_elements) {
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
// Split the chunk into a 128-byte-aligned part and a scalar tail.
size_t aligned_batch = (curr_batch / batch_size) * batch_size;
size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset);
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
if constexpr (std::is_same_v<T, float>) {
// float32 path: square in place, then vector-reduce the aligned part
if (aligned_batch > 0) {
__bang_mul((float *)(src + offset), (float *)(src + offset),
(float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0];
}
// Process unaligned tail element-wise
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = ((float *)(src + offset))[i];
res += val * val;
}
}
} else {
// half/bfloat16 path: widen to float first.
// NOTE(review): in-place 2-byte -> 4-byte widening at the same address;
// confirm the conversion intrinsics tolerate the overlap (sumTyped reads
// from data + len instead).
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)(src + offset), src + offset, curr_batch);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
}
// Estimate the chunk magnitude to derive an overflow-protection scale.
// NOTE(review): despite the name, for the aligned part this computes
// sum(|x|) / width (a column-sum average proxy), not the true maximum -
// confirm this scale heuristic is intended. __bang_abs overwrites the
// data with |x|, which is harmless because |x|^2 == x^2 below.
float max_val = 0.0f;
if (aligned_batch > 0) {
__bang_abs((float *)(src + offset), (float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch);
max_val = dst[0] / (aligned_batch / batch_size);
}
// Check for max value in tail elements (true max over the tail)
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = fabs(((float *)(src + offset))[i]);
max_val = std::max(max_val, val);
}
}
// Scale down, square, reduce, then undo the scale
float scale = (max_val > 1e3f) ? 1e3f / max_val : 1.0f;
// Process aligned portion
if (aligned_batch > 0) {
__bang_mul_scalar((float *)(src + offset), (float *)(src + offset), scale, aligned_batch);
__bang_mul((float *)(src + offset), (float *)(src + offset),
(float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0] / (scale * scale);
}
// Process unaligned tail element-wise
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = ((float *)(src + offset))[i] * scale;
res += val * val / (scale * scale);
}
}
}
processed += curr_batch;
}
return res;
}
// Reduce `max_batch` floats in NRAM `src` to their maximum via maxpool over
// the lanes followed by a 128-byte reduction.
// NOTE(review): unlike sumInternal there is no scalar fallback for
// max_batch / batch_size < 4 - confirm __bang_maxpool accepts small widths.
// NOTE(review): assumes __bang_argmax leaves the maximum VALUE in dst[0].
__mlu_func__ void maxInternal(float *dst, float *src, int max_batch) {
__bang_maxpool(
dst, src,
batch_size, // channel size
1, // height
max_batch / batch_size, // width
1, // kernel_height
max_batch / batch_size, // kernel_width
1, // stride_height
1 // stride_width
);
__bang_argmax(dst, dst, batch_size);
}
// Maximum of `len` T elements, written as a float into result[0].
// 2-byte types are expected staged at data + len; they are widened to float
// into the front of the buffer before the reduction.
template <typename T>
__mlu_func__ void maxTyped(float *result, T *data, size_t len) {
    if constexpr (std::is_same_v<T, bfloat16_t>) {
        __bang_bfloat162float((float *)data, data + len, len);
        maxInternal(result, (float *)data, len);
    } else if constexpr (std::is_same_v<T, half>) {
        __bang_half2float((float *)data, data + len, len);
        maxInternal(result, (float *)data, len);
    } else {
        maxInternal(result, data, len);
    }
}
// Maximum (as float) of `num_elements` values streamed from GDRAM `source`,
// staged through the NRAM buffer `src` in chunks of `max_batch`.
// `dst` is NRAM scratch receiving each full chunk's vector reduction.
template <typename T>
__mlu_func__ float max(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    float max_val = -INFINITY;
    // 2-byte types stage raw data at src + max_batch for the float widening.
    int offset = (sizeof(T) == 2 ? max_batch : 0);
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
        __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
        if (curr_batch < (size_t)max_batch) {
            // Partial tail. The previous code zero-filled the padding and ran
            // the vector reduction over the whole buffer, which corrupts the
            // result when every real element is negative (max clamped to 0).
            // Reduce only the real elements, scalar-style.
            for (size_t i = 0; i < curr_batch; ++i) {
                float val;
                if constexpr (std::is_same_v<T, half>) {
                    val = __half2float(src[offset + i]);
                } else if constexpr (std::is_same_v<T, bfloat16_t>) {
                    val = __bfloat162float(src[offset + i]);
                } else {
                    val = src[offset + i];
                }
                max_val = std::max(max_val, val);
            }
        } else {
            // Full chunk: every buffer element is real data, vector reduce.
            maxTyped(dst, src, max_batch);
            max_val = std::max(max_val, dst[0]);
        }
        processed += curr_batch;
    }
    return max_val;
}
// Maximum (as float) of `num_elements` values streamed from GDRAM.
// Same contract as max(); kept as a separate entry point to mirror the
// sum/sumBatched pair. Small vectors delegate to the element-wise path.
template <typename T>
__mlu_func__ float maxBatched(const T *source, T *src, float *dst, int num_elements, int max_batch) {
    constexpr int min_vector_size = 32;
    // For small vectors, use safer element-wise computation
    if (num_elements < min_vector_size) {
        return max(source, src, dst, num_elements, max_batch);
    }
    float max_val = -INFINITY;
    // 2-byte types stage raw data at src + max_batch for the float widening.
    int offset = (sizeof(T) == 2 ? max_batch : 0);
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
        __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
        if (curr_batch < (size_t)max_batch) {
            // Partial tail. The previous code zero-filled the padding and ran
            // the vector reduction over the whole buffer, which corrupts the
            // result when every real element is negative (max clamped to 0).
            // Reduce only the real elements, scalar-style.
            for (size_t i = 0; i < curr_batch; ++i) {
                float val;
                if constexpr (std::is_same_v<T, half>) {
                    val = __half2float(src[offset + i]);
                } else if constexpr (std::is_same_v<T, bfloat16_t>) {
                    val = __bfloat162float(src[offset + i]);
                } else {
                    val = src[offset + i];
                }
                max_val = std::max(max_val, val);
            }
        } else {
            // Full chunk: every buffer element is real data, vector reduce.
            maxTyped(dst, src, max_batch);
            max_val = std::max(max_val, dst[0]);
        }
        processed += curr_batch;
    }
    return max_val;
}
} // namespace op::common_bang::reduce_op
#endif // __INFINIOP_REDUCE_BANG_H__
......@@ -46,7 +46,7 @@ _TEST_CASES = [
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
InfiniDtype.BF16: {"atol": 8e-3, "rtol": 8e-3},
InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}
DEBUG = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment