// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"

namespace ck {
namespace detail {

// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanIgnore
{
    // Folds currVal into accuVal with ReduceOperation unless currVal is NaN, in which case
    // accuVal is left unchanged.
    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
    {
        if(!ck::math::isnan(currVal))
        {
            ReduceOperation{}(accuVal, currVal);
        }
    }

    // Two-value variant: forwards both values to ReduceOperation's three-argument call.
    // The pair is skipped as a whole (accuVal unchanged) when either value is NaN.
    // NOTE(review): if the three-argument ReduceOperation folds each value independently,
    // the non-NaN member of a mixed pair is dropped as well -- confirm this is intended.
    __device__ static inline void
    Calculate(AccDataType& accuVal, AccDataType currVal, AccDataType currVal1)
    {
        if(!ck::math::isnan(currVal) && !ck::math::isnan(currVal1))
        {
            ReduceOperation{}(accuVal, currVal, currVal1);
        }
    }
};

// Primary template, declared only; the PropagateNan = false and PropagateNan = true
// specializations provide the definitions.
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck;

// Does not check for NaN; does not guarantee NaNs be propagated to result
// e.g., given that max(a, b) = a > b ? a : b
// then  max(NaN, 1) returns 1
//       max(1, NaN) returns NaN
// since any comparison involving NaNs returns false
42
43
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
44
45
{
    // cppcheck-suppress constParameter
46
    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
47
48
49
    {
        ReduceOperation{}(accuVal, currVal);
    };
50
51
52
53
    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal, AccDataType currVal1)
    {
        ReduceOperation{}(accuVal, currVal,currVal1);
    };
54
55
};

// Check for NaN; guarantees NaNs be propagated to result
57
58
59
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{
60
    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
61
    {
62
63
64
        using ck::math::isnan;

        if(isnan(currVal))
65
66
67
68
69
70
71
        {
            accuVal = currVal;
        }
        else
        {
            ReduceOperation{}(accuVal, currVal);
        };
72
    };
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal, AccDataType currVal1)
    {
        using ck::math::isnan;

        if(isnan(currVal))
        {
            accuVal = currVal;
        }
	else if(isnan(currVal1))
        {
            accuVal = currVal1;
        }
        else
        {
            ReduceOperation{}(accuVal, currVal, currVal1);
        };
    };
90
91
92
93
};

// Primary template for index-tracking accumulation, declared only; the PropagateNan = false
// and PropagateNan = true specializations provide the definitions.
template <bool PropagateNan, typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck;

template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
{
98
    __host__ __device__ static inline void
99
    // cppcheck-suppress constParameter
100
101
102
103
    Calculate(AccDataType& accuVal,
              AccDataType currVal,
              IndexDataType& accuIndex,
              IndexDataType currIndex)
104
105
106
    {
        bool changed = false;

107
        ReduceOperation{}(accuVal, currVal, changed);
108
109
110
111
112
113

        if(changed)
            accuIndex = currIndex;
    };
};

template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
116
{
117
    // The method is called when the ReduceOperation is indexable and the user asked for indices
118
119
120
121
    __host__ __device__ static inline void Calculate(AccDataType& accuVal,
                                                     AccDataType currVal,
                                                     IndexDataType& accuIndex,
                                                     IndexDataType currIndex)
122
    {
123
124
125
        using ck::math::isnan;

        if(isnan(currVal))
126
127
128
129
130
131
132
133
        {
            accuVal   = currVal;
            accuIndex = currIndex;
        }
        else
        {
            bool changed = false;

134
            ReduceOperation{}(accuVal, currVal, changed);
135
136
137
138
139
140
141

            if(changed)
                accuIndex = currIndex;
        }
    };
};

} // namespace detail
} // namespace ck