conv_util.hpp 6.36 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#pragma once
5
6
7

#include <cstddef>
#include <cstdint>
#include <memory>
#include <tuple>
#include <vector>

Chao Liu's avatar
Chao Liu committed
8
9
10
11
12
13
#include "ck/ck.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
14

15
16
17
18
19
20
21
namespace ck {
namespace tensor_operation {
namespace device {

using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
                                              element_wise::PassThrough,
                                              element_wise::PassThrough>;
22
namespace instance {
23
24
25
26
27
28

void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);

29
} // namespace instance
30
31
32
33
} // namespace device
} // namespace tensor_operation
} // namespace ck

34
35
namespace test {
namespace conv {
36
37
38
39
40
41
42
43

// Shorthand for ck::Sequence, used for the sequence-valued arguments of
// DeviceConvNDFwdInstance.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Input, weights and output all use the PassThrough element-wise operation.
using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

// Forward-convolution device-op handle for the element-wise operations above.
using DeviceConvFwdNoOpPtr =
    ck::tensor_operation::device::DeviceConvFwdPtr<InElementOp, WeiElementOp, OutElementOp>;

// Convolution-forward specialization used as the ConvForwardSpecialization
// argument of DeviceConvNDFwdInstance.
static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

50
51
52
53
54
template <ck::index_t SpatialDims,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename AccDataType>
55
56
57
58
59
60
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
    DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
        // clang-format off
        InDataType,         // 
        WeiDataType,        //
        OutDataType,        //
61
        AccDataType,        // Accumulator data type.
62
63
64
65
66
        InElementOp,        // Input Elementwise Operation
        WeiElementOp,       // Weights Elementwise Operation
        OutElementOp,       // Output Elementwise Operation
        ConvFwdDefault,     // ConvForwardSpecialization
        SpatialDims,        // SptialDims
67
68
69
        256,                // BlockSize
        128,                // MPerBlock
        256,                // NPerBlock
70
        4,                  // K0PerBlock
71
72
73
74
75
76
        8,                  // K1
        32,                 // MPerXdl
        32,                 // NPerXdl
        2,                  // MXdlPerWave
        4,                  // NXdlPerWave
        S<4, 64, 1>,        // ABlockTransferThreadClusterLengths_K0_M_K1
77
78
79
        S<1, 0, 2>,         // ABlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,         // ABlockTransferSrcAccessOrder
        2,                  // ABlockTransferSrcVectorDim
80
81
        8,                  // ABlockTransferSrcScalarPerVector
        8,                  // ABlockTransferDstScalarPerVector_K1
82
        true,               // ABlockLdsAddExtraM
83
        S<4, 64, 1>,        // BBlockTransferThreadClusterLengths_K0_N_K1
84
85
86
        S<1, 0, 2>,         // BBlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,         // BBlockTransferSrcAccessOrder
        2,                  // BBlockTransferSrcVectorDim
87
88
89
        8,                  // BBlockTransferSrcScalarPerVector
        8,                  // BBlockTransferDstScalarPerVector_K1
        true,               // BBlockLdsAddExtraN
90
        7,                  // CThreadTransferSrcDstVectorDim
91
        1>;                // CThreadTransferDstScalarPerVector
92
93
94
// clang-format on

template <ck::index_t NDim,
95
96
97
98
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename AccDataType>
99
void get_test_convolution_fwd_instance(std::vector<DeviceConvFwdNoOpPtr>& instances)
100
{
101
102
    using ConvInstanceT =
        DeviceConvNDFwdInstance<NDim, InDataType, WeiDataType, OutDataType, AccDataType>;
103
    instances.emplace_back(std::make_unique<ConvInstanceT>());
104
105
}

106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// TODO (aosewski)
// Temporary solution to get all DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// instances. Once 2D has been switched over to DeviceConvNDFwdXdl, remove the
// ConvolutionNDFwdInstances structures.
//
// The primary template is only declared here, not defined. In the visible part
// of this header it is usable solely through its explicit specializations,
// each of which exposes a static Get(num_dim_spatial) function.
template <typename InDataType, typename WeiDataType, typename OutDataType>
struct ConvolutionNDFwdInstances;

template <>
struct ConvolutionNDFwdInstances<float, float, float>
{
    static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
    {
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
        if(num_dim_spatial == 2)
        {
121
            ck::tensor_operation::device::instance::
122
123
124
125
126
127
128
129
130
131
132
133
134
135
                add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
        }
        return conv_ptrs;
    }
};

template <>
struct ConvolutionNDFwdInstances<ck::half_t, ck::half_t, ck::half_t>
{
    static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
    {
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
        if(num_dim_spatial == 2)
        {
136
            ck::tensor_operation::device::instance::
137
138
139
140
141
142
143
144
145
146
147
148
149
150
                add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
        }
        return conv_ptrs;
    }
};

template <>
struct ConvolutionNDFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>
{
    static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
    {
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
        if(num_dim_spatial == 2)
        {
151
            ck::tensor_operation::device::instance::
152
153
154
155
156
157
158
159
160
161
162
163
164
165
                add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
        }
        return conv_ptrs;
    }
};

template <>
struct ConvolutionNDFwdInstances<int8_t, int8_t, int8_t>
{
    static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
    {
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
        if(num_dim_spatial == 2)
        {
166
            ck::tensor_operation::device::instance::
167
168
169
170
171
172
                add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
        }
        return conv_ptrs;
    }
};

173
174
} // namespace conv
} // namespace test