Commit 724d9692 authored by wooway777's avatar wooway777
Browse files

issue/240 - removed some redundant operations

parent 64b5a2bc
......@@ -86,20 +86,6 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
if constexpr (std::is_same_v<T, float>) {
// Process aligned portion
if (aligned_batch > 0) {
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0];
}
// Process unaligned tail
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
res += ((float *)(src + offset))[i];
}
}
} else {
// half/bfloat16 processing path
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)(src + offset), src + offset, curr_batch);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
......@@ -117,7 +103,6 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme
res += ((float *)(src + offset))[i];
}
}
}
processed += curr_batch;
}
......@@ -140,37 +125,21 @@ __mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_eleme
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
// Find max absolute value
float max_val = 0.0f;
for (size_t i = 0; i < curr_batch; ++i) {
float val = 0.0f;
if constexpr (std::is_same_v<T, half>) {
val = fabs(__half2float(src[offset + i]));
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
val = fabs(__bfloat162float(src[offset + i]));
} else {
val = fabs(src[offset + i]);
}
max_val = std::max(val, max_val);
}
float scale = (max_val > 1e3f) ? 1e3f / max_val : 1.0f; // Prevent overflow
float sum = 0.0f;
// Scaled computation
for (size_t i = 0; i < curr_batch; ++i) {
float val = 0.0f;
if constexpr (std::is_same_v<T, half>) {
val = __half2float(src[offset + i]) * scale;
val = __half2float(src[offset + i]);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
val = __bfloat162float(src[offset + i]) * scale;
val = __bfloat162float(src[offset + i]);
} else {
val = src[offset + i] * scale;
val = src[offset + i];
}
sum += val * val;
}
res += sum / (scale * scale);
res += sum;
processed += curr_batch;
}
......@@ -201,64 +170,25 @@ __mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int nu
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
if constexpr (std::is_same_v<T, float>) {
// float32 processing path
if (aligned_batch > 0) {
__bang_mul((float *)(src + offset), (float *)(src + offset),
(float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0];
}
// Process unaligned tail
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = ((float *)(src + offset))[i];
res += val * val;
}
}
} else {
// half/bfloat16 processing path
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)(src + offset), src + offset, curr_batch);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
}
// Find maximum absolute value
float max_val = 0.0f;
if (aligned_batch > 0) {
__bang_abs((float *)(src + offset), (float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch);
max_val = dst[0] / (aligned_batch / batch_size);
}
// Check for max value in tail elements
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = fabs(((float *)(src + offset))[i]);
max_val = std::max(max_val, val);
}
}
// Scale and compute squared sum
float scale = (max_val > 1e3f) ? 1e3f / max_val : 1.0f;
// Process aligned portion
if (aligned_batch > 0) {
__bang_mul_scalar((float *)(src + offset), (float *)(src + offset), scale, aligned_batch);
__bang_mul((float *)(src + offset), (float *)(src + offset),
(float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0] / (scale * scale);
res += dst[0];
}
// Process unaligned tail
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = ((float *)(src + offset))[i] * scale;
res += val * val / (scale * scale);
}
float val = ((float *)(src + offset))[i];
res += val * val;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment