threadwise_winograd_transform.cuh 7.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#pragma once
#include "constant_tensor_descriptor.cuh"

// Per-thread element-wise product step of a Winograd convolution.
//
// For every (n, k, h, w) of the transformed-output tile this computes
//     out[n][k][h][w] += sum over c of  wei[k][c][h][w] * in[n][c][h][w]
// The result is ADDED to p_out_transform_thread; the caller must initialise that
// buffer (e.g. to zero) before the first call.
//
// Descriptor lengths (verified at compile time below):
//   InTransThreadDesc  : {NPerThread, CPerThread, InTileSizeH, InTileSizeW}
//   WeiTransThreadDesc : {KPerThread, CPerThread, InTileSizeH, InTileSizeW}
//   OutTransThreadDesc : {NPerThread, KPerThread, InTileSizeH, InTileSizeW}
// Element offsets come from each descriptor's Get1dIndex(...).
//
// The descriptor arguments are passed only for their types. The template parameters
// S, R, OutTileSizeH and OutTileSizeW are not used by this function. None of the
// unsigned template parameters appear in the argument list, so callers must
// specify them explicitly.
//
// p_in_transform_thread / p_wei_transform_thread are only read; p_out_transform_thread
// is read and written. The three buffers must not alias (__restrict__).
template <class TFloat,
          class InTransThreadDesc,  //{NPerThread, CPerThread, InTileSizeH, InTileSizeW}
          class WeiTransThreadDesc, //{KPerThread, CPerThread, InTileSizeH, InTileSizeW}
          class OutTransThreadDesc, //{NPerThread, KPerThread, InTileSizeH, InTileSizeW}
          unsigned InTileSizeH,
          unsigned InTileSizeW,
          unsigned S,
          unsigned R,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW>
__device__ void
threadwise_winograd_calculate_transformed_output(InTransThreadDesc,
                                                 const TFloat* __restrict__ p_in_transform_thread,
                                                 WeiTransThreadDesc,
                                                 const TFloat* __restrict__ p_wei_transform_thread,
                                                 OutTransThreadDesc,
                                                 TFloat* __restrict__ p_out_transform_thread)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto in_transform_thread_desc  = InTransThreadDesc{};
    constexpr auto wei_transform_thread_desc = WeiTransThreadDesc{};
    constexpr auto out_transform_thread_desc = OutTransThreadDesc{};

    // The loops below are bounded by the output (n, k, h, w) and weight (c) lengths, but
    // also index the other descriptors with the same values; any mismatch would read
    // out of bounds, so reject it at compile time.
    static_assert(in_transform_thread_desc.GetLength(I0) == out_transform_thread_desc.GetLength(I0),
                  "input and output transform tiles must have the same NPerThread");
    static_assert(wei_transform_thread_desc.GetLength(I0) == out_transform_thread_desc.GetLength(I1),
                  "weight and output transform tiles must have the same KPerThread");
    static_assert(in_transform_thread_desc.GetLength(I1) == wei_transform_thread_desc.GetLength(I1),
                  "input and weight transform tiles must have the same CPerThread");
    static_assert(in_transform_thread_desc.GetLength(I2) == InTileSizeH &&
                      wei_transform_thread_desc.GetLength(I2) == InTileSizeH &&
                      out_transform_thread_desc.GetLength(I2) == InTileSizeH,
                  "transform tile height must equal InTileSizeH");
    static_assert(in_transform_thread_desc.GetLength(I3) == InTileSizeW &&
                      wei_transform_thread_desc.GetLength(I3) == InTileSizeW &&
                      out_transform_thread_desc.GetLength(I3) == InTileSizeW,
                  "transform tile width must equal InTileSizeW");

    for(unsigned n = 0; n < out_transform_thread_desc.GetLength(I0); ++n)
    {
        for(unsigned k = 0; k < out_transform_thread_desc.GetLength(I1); ++k)
        {
            for(unsigned h = 0; h < out_transform_thread_desc.GetLength(I2); ++h)
            {
                for(unsigned w = 0; w < out_transform_thread_desc.GetLength(I3); ++w)
                {
                    // The output offset does not depend on c: compute it once and keep
                    // the running sum in a register. The additions happen in the same
                    // order as accumulating into memory on every iteration.
                    const unsigned out_index = out_transform_thread_desc.Get1dIndex(n, k, h, w);

                    TFloat acc = p_out_transform_thread[out_index];

                    for(unsigned c = 0; c < wei_transform_thread_desc.GetLength(I1); ++c)
                    {
                        const unsigned in_index  = in_transform_thread_desc.Get1dIndex(n, c, h, w);
                        const unsigned wei_index = wei_transform_thread_desc.Get1dIndex(k, c, h, w);

                        acc += p_wei_transform_thread[wei_index] * p_in_transform_thread[in_index];
                    }

                    p_out_transform_thread[out_index] = acc;
                }
            }
        }
    }
}

// Per-thread Winograd F(2x2, 3x3) output (inverse) transform.
//
// For each (n, k) this reads the 4x4 transformed tile M = out_transform[n][k][:][:]
// and writes the 2x2 output tile
//     Y = A^T * M * A,   A^T = [ 1  1  1  0 ]
//                              [ 0  1 -1 -1 ]
// into out[n][k][:][:]. The output tile is overwritten, not accumulated.
//
// Descriptor lengths (verified at compile time below):
//   OutTransThreadDesc : {NPerThread, KPerThread,  InTileSizeH,  InTileSizeW}
//   OutThreadDesc      : {NPerThread, KPerThread, OutTileSizeH, OutTileSizeW}
// Only InTileSize 4x4, S == R == 3 and OutTileSize 2x2 is implemented.
//
// The descriptor arguments are passed only for their types. The unsigned template
// parameters are not deducible from the arguments, so callers must specify them
// explicitly. p_out_transform_thread is only read, and the two buffers must not
// alias (__restrict__).
template <class TFloat,
          class OutTransThreadDesc, //{NPerThread, KPerThread,  InTileSizeH,  InTileSizeW}
          class OutThreadDesc,      //{NPerThread, KPerThread, OutTileSizeH, OutTileSizeW}
          unsigned InTileSizeH,
          unsigned InTileSizeW,
          unsigned S,
          unsigned R,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW>
__device__ void
threadwise_winograd_reverse_transform_output(OutTransThreadDesc,
                                             const TFloat* __restrict__ p_out_transform_thread,
                                             OutThreadDesc,
                                             TFloat* __restrict__ p_out_thread)
{
    static_assert(InTileSizeH == 4, "only F(2x2, 3x3) is implemented: InTileSizeH must be 4");
    static_assert(InTileSizeW == 4, "only F(2x2, 3x3) is implemented: InTileSizeW must be 4");
    static_assert(S == 3, "only F(2x2, 3x3) is implemented: S must be 3");
    static_assert(R == 3, "only F(2x2, 3x3) is implemented: R must be 3");
    static_assert(OutTileSizeH == 2, "only F(2x2, 3x3) is implemented: OutTileSizeH must be 2");
    static_assert(OutTileSizeW == 2, "only F(2x2, 3x3) is implemented: OutTileSizeW must be 2");

    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto out_transform_thread_desc = OutTransThreadDesc{};
    constexpr auto out_thread_desc           = OutThreadDesc{};

    static_assert(InTileSizeH == out_transform_thread_desc.GetLength(I2),
                  "transformed tile height must equal InTileSizeH");
    static_assert(InTileSizeW == out_transform_thread_desc.GetLength(I3),
                  "transformed tile width must equal InTileSizeW");
    static_assert(OutTileSizeH == out_thread_desc.GetLength(I2),
                  "output tile height must equal OutTileSizeH");
    static_assert(OutTileSizeW == out_thread_desc.GetLength(I3),
                  "output tile width must equal OutTileSizeW");
    // Both tensors are indexed with the same (n, k) below.
    static_assert(out_transform_thread_desc.GetLength(I0) == out_thread_desc.GetLength(I0),
                  "transformed and output tiles must have the same NPerThread");
    static_assert(out_transform_thread_desc.GetLength(I1) == out_thread_desc.GetLength(I1),
                  "transformed and output tiles must have the same KPerThread");

    for(unsigned n = 0; n < out_thread_desc.GetLength(I0); ++n)
    {
        for(unsigned k = 0; k < out_thread_desc.GetLength(I1); ++k)
        {
            // Load the 4x4 transformed tile M[h][w] once; every element is used by
            // at least one of the four outputs.
            const TFloat m00 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 0)];
            const TFloat m01 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 1)];
            const TFloat m02 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 2)];
            const TFloat m03 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 3)];
            const TFloat m10 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 0)];
            const TFloat m11 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 1)];
            const TFloat m12 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 2)];
            const TFloat m13 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 3)];
            const TFloat m20 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 0)];
            const TFloat m21 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 1)];
            const TFloat m22 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 2)];
            const TFloat m23 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 3)];
            const TFloat m30 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 0)];
            const TFloat m31 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 1)];
            const TFloat m32 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 2)];
            const TFloat m33 = p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 3)];

            // Y = A^T * M * A, expanded. Each sum is evaluated left to right.
            p_out_thread[out_thread_desc.Get1dIndex(n, k, 0, 0)] =
                m00 + m01 + m02 + m10 + m11 + m12 + m20 + m21 + m22;

            p_out_thread[out_thread_desc.Get1dIndex(n, k, 0, 1)] =
                m01 - m02 - m03 + m11 - m12 - m13 + m21 - m22 - m23;

            p_out_thread[out_thread_desc.Get1dIndex(n, k, 1, 0)] =
                m10 + m11 + m12 - m20 - m21 - m22 - m30 - m31 - m32;

            p_out_thread[out_thread_desc.Get1dIndex(n, k, 1, 1)] =
                m11 - m12 - m13 - m21 + m22 + m23 - m31 + m32 + m33;
        }
    }
}