// conv_util.hpp
#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <tuple>
#include <vector>

#include "config.hpp"
#include "data_type.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "host_tensor.hpp"
#include "sequence.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Device-op pointer for forward convolutions whose input, weight and output
// element-wise operations are all element_wise::PassThrough.
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
                                              element_wise::PassThrough,
                                              element_wise::PassThrough>;

namespace device_conv2d_fwd_instance {

// Registration hooks for the 2D NHWC / KYXC / NHWK XDL forward-convolution
// instances, one per data type (f32, f16, bf16, int8). They are only declared
// here; the definitions are not part of this header. Each hook is expected to
// add its instances to the vector it receives by reference.
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<DeviceConvFwdNoOpPtr>& instances);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<DeviceConvFwdNoOpPtr>& instances);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
    std::vector<DeviceConvFwdNoOpPtr>& instances);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
    std::vector<DeviceConvFwdNoOpPtr>& instances);

} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace test {
namespace conv {

// Shorthand for ck::Sequence, used in the DeviceConvNDFwdInstance template
// argument list.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Element-wise operations applied to the input, weight and output tensors;
// all three are PassThrough in these tests.
using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

// Device-op pointer type for the three element-wise operations above. Because all
// three are PassThrough, this names the same type as
// ck::tensor_operation::device::DeviceConvFwdNoOpPtr.
using DeviceConvFwdNoOpPtr =
    ck::tensor_operation::device::DeviceConvFwdPtr<InElementOp, WeiElementOp, OutElementOp>;

// Forward-convolution specialization passed to DeviceConvNDFwdInstance.
static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

template <ck::index_t SpatialDims,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename AccDataType>
52
53
54
55
56
57
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
    DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
        // clang-format off
        InDataType,         // 
        WeiDataType,        //
        OutDataType,        //
58
        AccDataType,        // Accumulator data type.
59
60
61
62
63
        InElementOp,        // Input Elementwise Operation
        WeiElementOp,       // Weights Elementwise Operation
        OutElementOp,       // Output Elementwise Operation
        ConvFwdDefault,     // ConvForwardSpecialization
        SpatialDims,        // SptialDims
64
65
66
        256,                // BlockSize
        128,                // MPerBlock
        256,                // NPerBlock
67
        4,                  // K0PerBlock
68
69
70
71
72
73
        8,                  // K1
        32,                 // MPerXdl
        32,                 // NPerXdl
        2,                  // MXdlPerWave
        4,                  // NXdlPerWave
        S<4, 64, 1>,        // ABlockTransferThreadClusterLengths_K0_M_K1
74
75
76
        S<1, 0, 2>,         // ABlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,         // ABlockTransferSrcAccessOrder
        2,                  // ABlockTransferSrcVectorDim
77
78
        8,                  // ABlockTransferSrcScalarPerVector
        8,                  // ABlockTransferDstScalarPerVector_K1
79
        true,               // ABlockLdsAddExtraM
80
        S<4, 64, 1>,        // BBlockTransferThreadClusterLengths_K0_N_K1
81
82
83
        S<1, 0, 2>,         // BBlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,         // BBlockTransferSrcAccessOrder
        2,                  // BBlockTransferSrcVectorDim
84
85
86
        8,                  // BBlockTransferSrcScalarPerVector
        8,                  // BBlockTransferDstScalarPerVector_K1
        true,               // BBlockLdsAddExtraN
87
        7,                  // CThreadTransferSrcDstVectorDim
88
        1>;                // CThreadTransferDstScalarPerVector
89
90
91
// clang-format on

template <ck::index_t NDim,
92
93
94
95
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename AccDataType>
96
void get_test_convolution_fwd_instance(std::vector<DeviceConvFwdNoOpPtr>& instances)
97
{
98
99
    using ConvInstanceT =
        DeviceConvNDFwdInstance<NDim, InDataType, WeiDataType, OutDataType, AccDataType>;
100
    instances.emplace_back(std::make_unique<ConvInstanceT>());
101
102
}

// TODO(aosewski): Temporary way to collect every
// DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K instance for a
// given (input, weight, output) data-type triple. Remove the ConvolutionNDFwdInstances
// structures once the 2D path has been switched over to DeviceConvNDFwdXdl.
//
// The primary template is only declared (not defined) here, so only data-type
// combinations that have an explicit specialization can be used.
template <class InDataType, class WeiDataType, class OutDataType>
struct ConvolutionNDFwdInstances;

template <>
struct ConvolutionNDFwdInstances<float, float, float>
{
    static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
    {
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
        if(num_dim_spatial == 2)
        {
            ck::tensor_operation::device::device_conv2d_fwd_instance::
                add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
        }
        return conv_ptrs;
    }
};

template <>
struct ConvolutionNDFwdInstances<ck::half_t, ck::half_t, ck::half_t>
{
    static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
    {
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
        if(num_dim_spatial == 2)
        {
            ck::tensor_operation::device::device_conv2d_fwd_instance::
                add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
        }
        return conv_ptrs;
    }
};

template <>
struct ConvolutionNDFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>
{
    static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
    {
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
        if(num_dim_spatial == 2)
        {
            ck::tensor_operation::device::device_conv2d_fwd_instance::
                add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
        }
        return conv_ptrs;
    }
};

template <>
struct ConvolutionNDFwdInstances<int8_t, int8_t, int8_t>
{
    static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
    {
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
        if(num_dim_spatial == 2)
        {
            ck::tensor_operation::device::device_conv2d_fwd_instance::
                add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
        }
        return conv_ptrs;
    }
};

} // namespace conv
} // namespace test