device_reduce_instance_threadwise.hpp 9.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_HPP

#include <memory>
#include <tuple>
#include <vector>

#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// Compile-time list of ReductionConfiguration_2 entries used by
// add_device_reduce_instance_threadwise(), which creates one device op instance per
// tuple element, in the order listed here.
//
// Defining QUICK_REDUCE_TEST selects a short list of six configurations; otherwise the
// list of fourteen configurations in the #else branch is used.
//
// In every entry listed below, InSrcVectorDim == 0 is paired with KThreadSliceSize == 1
// and InSrcVectorDim == 1 is paired with MThreadSliceSize == 1.
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_threadwise = std::tuple<
    // clang-format off
    // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
    ReductionConfiguration_2<0, 2, 2, 2, 1>,
    ReductionConfiguration_2<0, 1, 1, 2, 1>,
    ReductionConfiguration_2<1, 2, 1, 1, 2>,
    ReductionConfiguration_2<1, 2, 2, 1, 2>,
    ReductionConfiguration_2<0, 1, 1, 3, 1>,
    ReductionConfiguration_2<1, 1, 1, 1, 3>
    // clang-format on
    >;
#else
using reduce_configuration_2_instances_threadwise = std::tuple<
    // clang-format off
    // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
    // Vectorized along dimension 0 (vector sizes 4 and 2).
    ReductionConfiguration_2<0, 4, 4, 8, 1>,
    ReductionConfiguration_2<0, 4, 4, 4, 1>,
    ReductionConfiguration_2<0, 2, 2, 2, 1>,

    // Input vectorized along dimension 1; output vector size stays 1.
    ReductionConfiguration_2<1, 4, 1, 1, 8>,
    ReductionConfiguration_2<1, 4, 1, 1, 4>,
    ReductionConfiguration_2<1, 2, 1, 1, 2>,

    // special instances
    // (vector sizes of 1 combined with odd thread slice sizes 3, 5, 7 and 11)
    ReductionConfiguration_2<0, 1, 1, 3, 1>,
    ReductionConfiguration_2<0, 1, 1, 5, 1>,
    ReductionConfiguration_2<0, 1, 1, 7, 1>,
    ReductionConfiguration_2<0, 1, 1, 11, 1>,

    ReductionConfiguration_2<1, 1, 1, 1, 3>,
    ReductionConfiguration_2<1, 1, 1, 1, 5>,
    ReductionConfiguration_2<1, 1, 1, 1, 7>,
    ReductionConfiguration_2<1, 1, 1, 1, 11>
    // clang-format on
    >;
#endif

// Pointer type stored in the instance vector filled by
// add_device_reduce_instance_threadwise(): a DeviceReducePtr parameterized on the
// InElementwiseOperation and AccElementwiseOperation that reduce_unary_operator yields for
// the given accumulation type and reduce operation.
// NOTE(review): the meaning of the two `true` arguments to reduce_unary_operator is not
// visible in this header -- see reduction_operator_mapping.hpp to confirm.
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
using deviceReduceThreadWisePtrType = DeviceReducePtr<
    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;

/// Appends one DeviceReduceThreadWise instance to `device_op_instances` for every entry
/// of reduce_configuration_2_instances_threadwise, in tuple order.
///
/// @tparam InDataType   Forwarded unchanged to DeviceReduceThreadWise.
/// @tparam AccDataType  Forwarded to DeviceReduceThreadWise; also selects the reduce
///                      operation and the element-wise operation types.
/// @tparam OutDataType  Forwarded unchanged to DeviceReduceThreadWise.
/// @tparam Rank         Forwarded unchanged to DeviceReduceThreadWise.
/// @tparam ReduceDims   Forwarded unchanged to DeviceReduceThreadWise (the instantiation
///                      macros in this header pass a Sequence<...>).
/// @tparam ReduceOpId   Reduce operation id; selects ReduceOperation and decides whether
///                      indices can be produced (only MIN, MAX and AMAX are indexable).
/// @tparam NanOpt       Any value other than NOT_PROPAGATE_NAN enables NaN propagation.
/// @tparam IndicesOpt   Indices are requested only when the operation is indexable and
///                      this is not NO_INDICES.
/// @param device_op_instances  Vector that receives the newly created instances; existing
///                             elements are left in place.
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          typename ReduceDims,
          ReduceTensorOp_t ReduceOpId,
          NanPropagation_t NanOpt,
          ReduceTensorIndices_t IndicesOpt>
void add_device_reduce_instance_threadwise(
    std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
{
    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
    using InElementwiseOperation =
        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
    using AccElementwiseOperation =
        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
            AccElementwiseOperation;

    // Only MIN, MAX and AMAX can report the index of the selected element.
    constexpr bool Indexable =
        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
         ReduceOpId == ReduceTensorOp_t::AMAX);
    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);

    constexpr bool PropagateNan = (NanOpt != NanPropagation_t::NOT_PROPAGATE_NAN);

    // The block/cluster configuration is fixed for all thread-wise instances; only the
    // per-thread slice and vector settings (cfg2) vary between instances.
    using cfg1 = ReductionConfiguration_1<256, 256, 1>;

    static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
        [&](auto j) {
            // j carries the compile-time index of the configuration to instantiate.
            using cfg2 =
                std::tuple_element_t<j.value, reduce_configuration_2_instances_threadwise>;

            using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
                                                            AccDataType,
                                                            OutDataType,
                                                            Rank,
                                                            ReduceDims,
                                                            ReduceOperation,
                                                            InElementwiseOperation,
                                                            AccElementwiseOperation,
                                                            PropagateNan,
                                                            NeedIndices,
                                                            cfg1::BlockSize_,
                                                            cfg1::MThreadClusterSize_,
                                                            cfg1::KThreadClusterSize_,
                                                            cfg2::MThreadSliceSize_,
                                                            cfg2::KThreadSliceSize_,
                                                            cfg2::InSrcVectorDim_,
                                                            cfg2::InSrcVectorSize_,
                                                            cfg2::OutDstVectorSize_>;

            // Construct the instance directly on the heap instead of building a
            // temporary and copy/move-constructing from it.
            device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
        });
}

// Expands to an explicit instantiation definition of
// add_device_reduce_instance_threadwise for the given data types, rank and enum-typed
// ReduceOpId / NanOpt / IndicesOpt. The trailing variadic arguments become the
// Sequence<...> passed as ReduceDims. The expansion has no terminating semicolon; the
// user of the macro supplies it.
#define ADD_THREADWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
    template void add_device_reduce_instance_threadwise<inT,                                     \
                                                        compT,                                   \
                                                        outT,                                    \
                                                        Rank,                                    \
                                                        Sequence<__VA_ARGS__>,                   \
                                                        ReduceOpId,                              \
                                                        NanOpt,                                  \
                                                        IndicesOpt>(                             \
        std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)

// Same as ADD_THREADWISE_INST_BY_TYPE, but ReduceOpId, NanOpt and IndicesOpt are given as
// plain values that are static_cast to ReduceTensorOp_t, NanPropagation_t and
// ReduceTensorIndices_t respectively before being forwarded.
#define ADD_THREADWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
    ADD_THREADWISE_INST_BY_TYPE(inT,                                                           \
                                compT,                                                         \
                                outT,                                                          \
                                static_cast<ReduceTensorOp_t>(ReduceOpId),                     \
                                static_cast<NanPropagation_t>(NanOpt),                         \
                                static_cast<ReduceTensorIndices_t>(IndicesOpt),                \
                                Rank,                                                          \
                                __VA_ARGS__)

// Expands to an explicit instantiation declaration (`extern template`) of
// add_device_reduce_instance_threadwise, i.e. the counterpart of
// ADD_THREADWISE_INST_BY_TYPE for translation units that only reference an instantiation
// defined elsewhere. The trailing variadic arguments become the Sequence<...> passed as
// ReduceDims. The parameter type is spelled with deviceReduceThreadWisePtrType, exactly
// as in the function template and in ADD_THREADWISE_INST_BY_TYPE, so declaration and
// definition cannot drift apart. The expansion has no terminating semicolon; the user of
// the macro supplies it.
#define ADD_THREADWISE_INST_REF_BY_TYPE(                                                      \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                              \
    extern template void add_device_reduce_instance_threadwise<inT,                           \
                                                               compT,                         \
                                                               outT,                          \
                                                               Rank,                          \
                                                               Sequence<__VA_ARGS__>,         \
                                                               ReduceOpId,                    \
                                                               NanOpt,                        \
                                                               IndicesOpt>(                   \
        std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)

// Same as ADD_THREADWISE_INST_REF_BY_TYPE, but ReduceOpId, NanOpt and IndicesOpt are given
// as plain values that are static_cast to ReduceTensorOp_t, NanPropagation_t and
// ReduceTensorIndices_t respectively before being forwarded.
#define ADD_THREADWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
    ADD_THREADWISE_INST_REF_BY_TYPE(inT,                                                           \
                                    compT,                                                         \
                                    outT,                                                          \
                                    static_cast<ReduceTensorOp_t>(ReduceOpId),                     \
                                    static_cast<NanPropagation_t>(NanOpt),                         \
                                    static_cast<ReduceTensorIndices_t>(IndicesOpt),                \
                                    Rank,                                                          \
                                    __VA_ARGS__)

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif