Unverified Commit 3a91947e authored by thatPepe, committed by GitHub

Issue/459 (#460)

* issue/459 - Support more data type combinations

* issue/459 - added test cases for 9G7B and 9G70B

* issue/459 - modified rms kernel to support larger tensors
parent 2a81c8bd
@@ -3,7 +3,6 @@
#include "rms_norm_bang.h"
__nram__ char nram_buffer[NRAM_MAX_SIZE];
template <typename T, typename Tw>
__mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight,
@@ -16,80 +15,202 @@ __mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight,
}
int vector_size = shape[num_dims - 1];
// Task distribution across cores
int remaining_tasks = batch_volume % taskDim;
int base_tasks_per_core = batch_volume / taskDim;
int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0);
int task_start_idx = (taskId < remaining_tasks ? taskId * (base_tasks_per_core + 1) : remaining_tasks * (base_tasks_per_core + 1) + (taskId - remaining_tasks) * base_tasks_per_core);
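// Example: batch_volume = 10 on taskDim = 4 gives base_tasks_per_core = 2 and
// remaining_tasks = 2, so cores 0..3 take 3, 3, 2, 2 rows starting at indices
// 0, 3, 6, 8 -- the first `remaining_tasks` cores each absorb one extra row.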
// Determine optimal batch size based on vector size
int max_batch_size;
if (vector_size <= 64) {
// For small vectors, process the entire vector at once
max_batch_size = vector_size;
} else {
// For larger vectors, size the batch to the NRAM budget: each element needs
// one input slot (T), one weight slot (Tw) and two float staging slots.
max_batch_size = (NRAM_MAX_SIZE - 256) / (sizeof(T) + sizeof(Tw) + 2 * sizeof(float));
max_batch_size = std::min(max_batch_size, vector_size);
max_batch_size = (max_batch_size / 64) * 64; // Align to 64 elements
}
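// Example, assuming NRAM_MAX_SIZE is 512 KiB with T = half and Tw = float:
// (524288 - 256) / (2 + 4 + 8) = 37430 elements, rounded down to 37376 by the
// 64-element alignment.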
constexpr int reduce_buffer_size = 128 / sizeof(float);
// NRAM buffer allocation with dynamic sizing
float *reduction_buffer = (float *)nram_buffer;
T *input_cache = (T *)(reduction_buffer + reduce_buffer_size);
Tw *weight_cache = (Tw *)(input_cache + max_batch_size);
float *float_buffer = (float *)(weight_cache + max_batch_size);
float *weight_float_buffer = (float *)(float_buffer + max_batch_size);
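// Resulting NRAM layout, front to back: 32 floats of reduction scratch,
// max_batch_size elements of T (input), max_batch_size elements of Tw (weight),
// and two max_batch_size float staging areas for the converted input and weight.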
// Process vectors assigned to current core
for (int task_idx = 0; task_idx < actual_tasks; ++task_idx) {
int current_index = task_start_idx + task_idx;
int input_offset = 0;
int output_offset = 0;
int temp_index = current_index;
// Calculate memory offsets for current task
for (int dim = 0; dim < num_dims - 1; ++dim) {
int dim_coord = temp_index % shape[dim];
input_offset += dim_coord * input_strides[dim];
output_offset += dim_coord * output_strides[dim];
temp_index /= shape[dim];
}
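// Example: for shape (4, 4, 2048) with input_strides (16384, 4096, 1),
// current_index = 5 decodes to coordinates (1, 1), so
// input_offset = 1 * 16384 + 1 * 4096 = 20480.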
// Compute sum of squares
float sum_squared = 0.0f;
if (vector_size <= 128) {
// Small vector optimization: process entire vector at once
__memcpy(input_cache, input + input_offset, vector_size * sizeof(T), GDRAM2NRAM);
// Convert to float and square
if constexpr (std::is_same<T, half>::value) {
__bang_half2float(float_buffer, input_cache, vector_size);
} else if constexpr (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(float_buffer, input_cache, vector_size);
} else {
__memcpy(float_buffer, input_cache, vector_size * sizeof(float), NRAM2NRAM);
}
__bang_mul(float_buffer, float_buffer, float_buffer, vector_size);
// Direct accumulation for small vectors
for (int i = 0; i < vector_size; ++i) {
sum_squared += float_buffer[i];
}
} else {
// Large vector processing with chunking
__bang_write_zero(reduction_buffer, reduce_buffer_size);
size_t processed_elements = 0;
while (processed_elements < vector_size) {
size_t current_batch = std::min((size_t)max_batch_size, vector_size - processed_elements);
// Load input data
__memcpy(input_cache, input + input_offset + processed_elements * input_strides[num_dims - 1],
current_batch * sizeof(T), GDRAM2NRAM);
// Convert to float and square
if constexpr (std::is_same<T, half>::value) {
__bang_half2float(float_buffer, input_cache, current_batch);
} else if constexpr (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(float_buffer, input_cache, current_batch);
} else {
__memcpy(float_buffer, input_cache, current_batch * sizeof(float), NRAM2NRAM);
}
__bang_mul(float_buffer, float_buffer, float_buffer, current_batch);
// Accumulate squared values
float batch_sum = 0.0f;
if (current_batch >= 128) {
op::common_bang::reduce_op::sumInternal(reduction_buffer, float_buffer, current_batch);
batch_sum = reduction_buffer[0];
} else {
for (size_t i = 0; i < current_batch; ++i) {
batch_sum += float_buffer[i];
}
}
sum_squared += batch_sum;
processed_elements += current_batch;
}
}
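// Note: the library reduction in sumInternal is only invoked for chunks of at
// least 128 floats; shorter tails fall back to the scalar loop above.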
// Compute normalization factor
float rms_value = sqrtf(sum_squared / vector_size + epsilon);
float inv_rms = 1.0f / rms_value;
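// inv_rms = 1 / sqrt(mean(x^2) + epsilon); applied below as y = x * w * inv_rms.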
// Process vector for normalization
if (vector_size <= max_batch_size) {
// Process entire vector at once for small vectors
__memcpy(input_cache, input + input_offset, vector_size * sizeof(T), GDRAM2NRAM);
__memcpy(weight_cache, weight, vector_size * sizeof(Tw), GDRAM2NRAM);
// Convert input to float
if constexpr (std::is_same<T, half>::value) {
__bang_half2float(float_buffer, input_cache, vector_size);
} else if constexpr (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(float_buffer, input_cache, vector_size);
} else {
__memcpy(float_buffer, input_cache, vector_size * sizeof(float), NRAM2NRAM);
}
// Convert weight to float if needed
if constexpr (std::is_same<Tw, half>::value) {
__bang_half2float(weight_float_buffer, weight_cache, vector_size);
} else if constexpr (std::is_same<Tw, bfloat16_t>::value) {
__bang_bfloat162float(weight_float_buffer, weight_cache, vector_size);
} else {
__memcpy(weight_float_buffer, weight_cache, vector_size * sizeof(float), NRAM2NRAM);
}
// Multiply by weight and apply normalization
__bang_mul(float_buffer, float_buffer, weight_float_buffer, vector_size);
__bang_mul_scalar(float_buffer, float_buffer, inv_rms, vector_size);
// Convert back to output type
if constexpr (std::is_same<T, half>::value) {
__bang_float2half(input_cache, float_buffer, vector_size);
} else if constexpr (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16(input_cache, float_buffer, vector_size);
} else {
__memcpy(input_cache, float_buffer, vector_size * sizeof(float), NRAM2NRAM);
}
// Store results
__memcpy(output + output_offset, input_cache, vector_size * sizeof(T), NRAM2GDRAM);
} else {
// Large vector processing with chunking
size_t processed_elements = 0;
while (processed_elements < vector_size) {
size_t current_batch = std::min((size_t)max_batch_size, vector_size - processed_elements);
// Load input and weight data
__memcpy(input_cache, input + input_offset + processed_elements * input_strides[num_dims - 1],
current_batch * sizeof(T), GDRAM2NRAM);
__memcpy(weight_cache, weight + processed_elements, current_batch * sizeof(Tw), GDRAM2NRAM);
// Convert input to float
if constexpr (std::is_same<T, half>::value) {
__bang_half2float(float_buffer, input_cache, current_batch);
} else if constexpr (std::is_same<T, bfloat16_t>::value) {
__bang_bfloat162float(float_buffer, input_cache, current_batch);
} else {
__memcpy(float_buffer, input_cache, current_batch * sizeof(float), NRAM2NRAM);
}
// Convert weight to float if needed
if constexpr (std::is_same<Tw, half>::value) {
__bang_half2float(weight_float_buffer, weight_cache, current_batch);
} else if constexpr (std::is_same<Tw, bfloat16_t>::value) {
__bang_bfloat162float(weight_float_buffer, weight_cache, current_batch);
} else {
__memcpy(weight_float_buffer, weight_cache, current_batch * sizeof(float), NRAM2NRAM);
}
// Multiply by weight and apply normalization
__bang_mul(float_buffer, float_buffer, weight_float_buffer, current_batch);
__bang_mul_scalar(float_buffer, float_buffer, inv_rms, current_batch);
// Convert back to output type
if constexpr (std::is_same<T, half>::value) {
__bang_float2half(input_cache, float_buffer, current_batch);
} else if constexpr (std::is_same<T, bfloat16_t>::value) {
__bang_float2bfloat16(input_cache, float_buffer, current_batch);
} else {
__memcpy(input_cache, float_buffer, current_batch * sizeof(float), NRAM2NRAM);
}
// Store results
__memcpy(output + output_offset + processed_elements * output_strides[num_dims - 1],
input_cache, current_batch * sizeof(T), NRAM2GDRAM);
processed_elements += current_batch;
}
}
}
}
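For reference, a minimal host-side model of what the kernel computes per row (a sketch, assuming float input/output and a contiguous last dimension; not part of this commit):
#include <cmath>
#include <cstddef>
// Reference RMSNorm over one row: y[i] = x[i] * w[i] / sqrt(mean(x^2) + eps),
// accumulating the sum of squares in float, as the kernel does.
void rmsnorm_row_ref(float *y, const float *x, const float *w,
                     size_t n, float eps) {
    float sum_sq = 0.0f;
    for (size_t i = 0; i < n; ++i) sum_sq += x[i] * x[i];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / n + eps);
    for (size_t i = 0; i < n; ++i) y[i] = x[i] * w[i] * inv_rms;
}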
@@ -178,18 +299,24 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
int core_per_cluster = _opaque->internal->getCorePerCluster();
int cluster_count = _opaque->internal->getClusterCount();
// Dispatch based on data types - support all combinations
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
rmsnormUnion<half, half>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<half, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_BF16) {
rmsnormUnion<half, bfloat16_t>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_F32) {
if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<float, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_F16) {
rmsnormUnion<float, half>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_BF16) {
rmsnormUnion<float, bfloat16_t>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
@@ -198,6 +325,8 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
rmsnormUnion<bfloat16_t, bfloat16_t>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_F32) {
rmsnormUnion<bfloat16_t, float>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else if (_info.wtype == INFINI_DTYPE_F16) {
rmsnormUnion<bfloat16_t, half>(workspace, core_per_cluster, cluster_count, queue, y, x, w, _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), _info.epsilon, _info.ndim());
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
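The type combinations above all follow the same call shape; as a sketch only (the DISPATCH_RMSNORM macro is hypothetical, not part of this commit), the repetition could be folded into:
#define DISPATCH_RMSNORM(T, TW) \
    rmsnormUnion<T, TW>(workspace, core_per_cluster, cluster_count, queue, y, x, w, \
                        _info.shape.data(), _info.y_strides.data(), _info.x_strides.data(), \
                        _info.epsilon, _info.ndim())
// e.g. the F16-activation / F32-weight branch becomes: DISPATCH_RMSNORM(half, float);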
@@ -34,6 +34,8 @@ _TEST_CASES_ = [
((4, 4, 2048), (4, 4, 2048), (2048,), None, None),
((4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1)),
((4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1)),
((15, 3584), (15, 3584), (3584,), None, None),
((15, 8192), (15, 8192), (8192,), None, None),
]
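# Each tuple above is (y_shape, x_shape, w_shape, y_strides, x_strides) --
# field names inferred from the kernel dispatch; None means contiguous layout.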
# w (weight) types