device_dummy_dynamic_transform.hpp 4.5 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dummy_dynamic_transform.hpp"

/// Host-side driver for the DummyDynamicTransform gridwise operation.
///
/// Converts the compile-time tensor descriptors (InDesc / WeiDesc / OutDesc) and the
/// convolution parameter sequences into their runtime counterparts. It then copies the
/// three host tensors to device memory and runs 5 rounds of `nrepeat` kernel launches
/// each, reporting the average time per launch for every round. Finally it copies the
/// output device buffer back into `out_nkhw`.
///
/// @param in_nchw   host input tensor; copied to the device before launching
/// @param wei_kcyx  host weight tensor; copied to the device before launching
/// @param out_nkhw  host output tensor; copied to the device before launching and
///                  overwritten with the device buffer contents afterwards
/// @param nrepeat   number of kernel launches per timed round; when <= 0 no kernel is
///                  launched and no average is printed for that round
template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void device_dummy_dynamic_transform(InDesc,
                                    const Tensor<T>& in_nchw,
                                    WeiDesc,
                                    const Tensor<T>& wei_kcyx,
                                    OutDesc,
                                    Tensor<T>& out_nkhw,
                                    ConvStrides,
                                    ConvDilations,
                                    InLeftPads,
                                    InRightPads,
                                    ck::index_t nrepeat)
{
    using namespace ck;

    // Runtime (dynamic) descriptors built from the static descriptors' lengths/strides.
    const auto in_nchw_desc = make_dynamic_native_tensor_descriptor(to_array(InDesc::GetLengths()),
                                                                    to_array(InDesc::GetStrides()));
    const auto wei_kcyx_desc = make_dynamic_native_tensor_descriptor(
        to_array(WeiDesc::GetLengths()), to_array(WeiDesc::GetStrides()));
    const auto out_nkhw_desc = make_dynamic_native_tensor_descriptor(
        to_array(OutDesc::GetLengths()), to_array(OutDesc::GetStrides()));

    const auto conv_strides   = to_array(ConvStrides{});
    const auto conv_dilations = to_array(ConvDilations{});
    const auto in_left_pads   = to_array(InLeftPads{});
    const auto in_right_pads  = to_array(InRightPads{});

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    constexpr index_t BlockSize = 256;
    constexpr index_t GridSize  = 1;

    // index_t is a signed type; print through int/%d rather than %u.
    printf("%s: BlockSize %d, GridSize %d \n",
           __func__,
           static_cast<int>(BlockSize),
           static_cast<int>(GridSize));

    using dummy_transform = DummyDynamicTransform<BlockSize>;

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            // NOTE(review): the device pointers are passed in (in, wei, out) order, but the
            // descriptors are passed as (wei, in, out). Kept as-is; confirm this matches
            // DummyDynamicTransform's Run signature.
            launch_kernel(run_gridwise_operation<dummy_transform,
                                                 index_t* const,
                                                 index_t* const,
                                                 float* const,
                                                 const DynamicNativeTensorDescriptor<4>,
                                                 const DynamicNativeTensorDescriptor<4>,
                                                 const DynamicNativeTensorDescriptor<4>,
                                                 const Array<index_t, 2>,
                                                 const Array<index_t, 2>,
                                                 const Array<index_t, 2>,
                                                 const Array<index_t, 2>,
                                                 index_t,
                                                 index_t,
                                                 index_t,
                                                 index_t>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<index_t*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<index_t*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()),
                          wei_kcyx_desc,
                          in_nchw_desc,
                          out_nkhw_desc,
                          conv_strides,
                          conv_dilations,
                          in_left_pads,
                          in_right_pads,
                          10,
                          10,
                          10,
                          10);
        }

        // Previously the timer was started but never stopped or read, so no timing was
        // ever reported.
        timer.End();

        if(nrepeat > 0)
        {
            // Assumes GetElapsedTime() reports milliseconds (as hipEventElapsedTime does);
            // confirm against device.hpp.
            const float ave_time = timer.GetElapsedTime() / nrepeat;

            std::cout << "Average time : " << ave_time << " ms" << std::endl;
        }
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}