"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "f57d86f4d1fdadffb08e7d882a94718df123bb46"
Commit 724d9692 authored by wooway777's avatar wooway777
Browse files

issue/240 - removed some redundant operations

parent 64b5a2bc
...@@ -86,36 +86,21 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme ...@@ -86,36 +86,21 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme
// Copy data to NRAM // Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
if constexpr (std::is_same_v<T, float>) { if constexpr (std::is_same_v<T, half>) {
// Process aligned portion __bang_half2float((float *)(src + offset), src + offset, curr_batch);
if (aligned_batch > 0) { } else if constexpr (std::is_same_v<T, bfloat16_t>) {
sumInternal(dst, (float *)(src + offset), aligned_batch); __bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
res += dst[0]; }
}
// Process unaligned tail
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
res += ((float *)(src + offset))[i];
}
}
} else {
// half/bfloat16 processing path
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)(src + offset), src + offset, curr_batch);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
}
// Process aligned portion // Process aligned portion
if (aligned_batch > 0) { if (aligned_batch > 0) {
sumInternal(dst, (float *)(src + offset), aligned_batch); sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0]; res += dst[0];
} }
// Process unaligned tail // Process unaligned tail
if (remainder > 0) { if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) { for (size_t i = aligned_batch; i < curr_batch; ++i) {
res += ((float *)(src + offset))[i]; res += ((float *)(src + offset))[i];
}
} }
} }
...@@ -140,37 +125,21 @@ __mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_eleme ...@@ -140,37 +125,21 @@ __mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_eleme
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
// Find max absolute value
float max_val = 0.0f;
for (size_t i = 0; i < curr_batch; ++i) {
float val = 0.0f;
if constexpr (std::is_same_v<T, half>) {
val = fabs(__half2float(src[offset + i]));
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
val = fabs(__bfloat162float(src[offset + i]));
} else {
val = fabs(src[offset + i]);
}
max_val = std::max(val, max_val);
}
float scale = (max_val > 1e3f) ? 1e3f / max_val : 1.0f; // Prevent overflow
float sum = 0.0f; float sum = 0.0f;
// Scaled computation
for (size_t i = 0; i < curr_batch; ++i) { for (size_t i = 0; i < curr_batch; ++i) {
float val = 0.0f; float val = 0.0f;
if constexpr (std::is_same_v<T, half>) { if constexpr (std::is_same_v<T, half>) {
val = __half2float(src[offset + i]) * scale; val = __half2float(src[offset + i]);
} else if constexpr (std::is_same_v<T, bfloat16_t>) { } else if constexpr (std::is_same_v<T, bfloat16_t>) {
val = __bfloat162float(src[offset + i]) * scale; val = __bfloat162float(src[offset + i]);
} else { } else {
val = src[offset + i] * scale; val = src[offset + i];
} }
sum += val * val; sum += val * val;
} }
res += sum / (scale * scale); res += sum;
processed += curr_batch; processed += curr_batch;
} }
...@@ -201,64 +170,25 @@ __mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int nu ...@@ -201,64 +170,25 @@ __mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int nu
// Copy data to NRAM // Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM); __memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
if constexpr (std::is_same_v<T, float>) { if constexpr (std::is_same_v<T, half>) {
// float32 processing path __bang_half2float((float *)(src + offset), src + offset, curr_batch);
if (aligned_batch > 0) { } else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_mul((float *)(src + offset), (float *)(src + offset), __bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
(float *)(src + offset), aligned_batch); }
sumInternal(dst, (float *)(src + offset), aligned_batch);
res += dst[0];
}
// Process unaligned tail
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = ((float *)(src + offset))[i];
res += val * val;
}
}
} else {
// half/bfloat16 processing path
if constexpr (std::is_same_v<T, half>) {
__bang_half2float((float *)(src + offset), src + offset, curr_batch);
} else if constexpr (std::is_same_v<T, bfloat16_t>) {
__bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
}
// Find maximum absolute value
float max_val = 0.0f;
if (aligned_batch > 0) {
__bang_abs((float *)(src + offset), (float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch);
max_val = dst[0] / (aligned_batch / batch_size);
}
// Check for max value in tail elements
if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = fabs(((float *)(src + offset))[i]);
max_val = std::max(max_val, val);
}
}
// Scale and compute squared sum
float scale = (max_val > 1e3f) ? 1e3f / max_val : 1.0f;
// Process aligned portion // Process aligned portion
if (aligned_batch > 0) { if (aligned_batch > 0) {
__bang_mul_scalar((float *)(src + offset), (float *)(src + offset), scale, aligned_batch); __bang_mul((float *)(src + offset), (float *)(src + offset),
__bang_mul((float *)(src + offset), (float *)(src + offset), (float *)(src + offset), aligned_batch);
(float *)(src + offset), aligned_batch); sumInternal(dst, (float *)(src + offset), aligned_batch);
sumInternal(dst, (float *)(src + offset), aligned_batch); res += dst[0];
res += dst[0] / (scale * scale); }
}
// Process unaligned tail // Process unaligned tail
if (remainder > 0) { if (remainder > 0) {
for (size_t i = aligned_batch; i < curr_batch; ++i) { for (size_t i = aligned_batch; i < curr_batch; ++i) {
float val = ((float *)(src + offset))[i] * scale; float val = ((float *)(src + offset))[i];
res += val * val / (scale * scale); res += val * val;
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment