reduce.cc 1.59 KB
Newer Older
1
2
3
4
#include "reduce.h"

#include <limits>

namespace op::common_cpu::reduce_op {

5
6
// Sums `len` elements of `data`, reading one element every `stride` positions.
// Each element is converted with utils::cast<float> and accumulated in float.
//
// data:   pointer to the first element to read.
// len:    number of elements to accumulate; 0 yields 0.0f.
// stride: distance, in elements, between consecutive reads; may be negative.
//
// Returns the float sum of the converted elements.
template <typename HalfType>
float sum_half_impl(const HalfType *data, size_t len, ptrdiff_t stride) {
    float result = 0;
    for (size_t i = 0; i < len; i++) {
        // Compute the offset in signed arithmetic: `i * stride` would convert a
        // negative stride to an unsigned value and wrap before the subscript.
        const ptrdiff_t offset = static_cast<ptrdiff_t>(i) * stride;
        result += utils::cast<float>(data[offset]);
    }
    return result;
}

14
15
// Returns the maximum of `len` elements of `data`, reading one element every
// `stride` positions. Each element is converted with utils::cast<float> and
// compared in float.
//
// data:   pointer to the first element to read.
// len:    number of elements to scan.
// stride: distance, in elements, between consecutive reads; may be negative.
//
// Returns the largest converted element, or negative infinity when `len` is 0.
template <typename HalfType>
float max_half_impl(const HalfType *data, size_t len, ptrdiff_t stride) {
    // An empty range has no first element to seed from; reading data[0] would
    // be out of bounds, so return the identity element of max instead.
    if (len == 0) {
        return -std::numeric_limits<float>::infinity();
    }
    float result = utils::cast<float>(data[0]);
    for (size_t i = 1; i < len; i++) {
        // Compute the offset in signed arithmetic: `i * stride` would convert a
        // negative stride to an unsigned value and wrap before the subscript.
        const ptrdiff_t offset = static_cast<ptrdiff_t>(i) * stride;
        result = std::max(result, utils::cast<float>(data[offset]));
    }
    return result;
}

23
24
// Sums the squares of `len` elements of `data`, reading one element every
// `stride` positions. Each element is converted with utils::cast<float> first,
// so both the squaring and the accumulation happen in float.
//
// data:   pointer to the first element to read.
// len:    number of elements to accumulate; 0 yields 0.0f.
// stride: distance, in elements, between consecutive reads; may be negative.
//
// Returns the float sum of the squared converted elements.
template <typename HalfType>
float sumSquared_half_impl(const HalfType *data, size_t len, ptrdiff_t stride) {
    float result = 0;
    for (size_t i = 0; i < len; i++) {
        // Compute the offset in signed arithmetic: `i * stride` would convert a
        // negative stride to an unsigned value and wrap before the subscript.
        const ptrdiff_t offset = static_cast<ptrdiff_t>(i) * stride;
        const float val = utils::cast<float>(data[offset]);
        result += val * val;
    }
    return result;
}

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// fp16
float sum(const fp16_t *data, size_t len, ptrdiff_t stride) {
    return sum_half_impl(data, len, stride);
}

float max(const fp16_t *data, size_t len, ptrdiff_t stride) {
    return max_half_impl(data, len, stride);
}

float sumSquared(const fp16_t *data, size_t len, ptrdiff_t stride) {
    return sumSquared_half_impl(data, len, stride);
}

// bf16
float sum(const bf16_t *data, size_t len, ptrdiff_t stride) {
    return sum_half_impl(data, len, stride);
}

float max(const bf16_t *data, size_t len, ptrdiff_t stride) {
    return max_half_impl(data, len, stride);
}

float sumSquared(const bf16_t *data, size_t len, ptrdiff_t stride) {
    return sumSquared_half_impl(data, len, stride);
}

59
} // namespace op::common_cpu::reduce_op