// threadwise_2d_tensor_op.hip.hpp
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
// Apply the unary functor f in place to every element of the 2D tensor described
// by Desc. Element (i, j) is the lvalue p[desc.Get1dIndex(i, j)]; the outer loop
// walks dimension 0 and the inner loop walks dimension 1.
// Desc is a default-constructible descriptor type (an instance is created via Desc{}).
template <class Float, class Desc, class F>
__device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
{
    constexpr auto desc = Desc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
    }
#endif

    // The extents come from the compile-time descriptor, so they are constants.
    constexpr index_t n0 = desc.GetLength(I0);
    constexpr index_t n1 = desc.GetLength(I1);

    for(index_t i = 0; i < n0; ++i)
    {
        for(index_t j = 0; j < n1; ++j)
        {
            f(p[desc.Get1dIndex(i, j)]);
        }
    }
}

// Apply the binary functor f(src_elem, dst_elem) over a 2D region while permuting
// dimensions between source and destination.
//
// The iteration space is SrcOpLengths: the outer loop walks dim 0 and the inner loop
// walks dim 1. For source coordinate idx = {i0, i1}, the destination coordinate is
// (idx[MapDst2Src[0]], idx[MapDst2Src[1]]). In other words, destination dimension k
// takes its coordinate from source dimension MapDst2Src[k].
// f receives p_src[...] and p_dst[...] as lvalues and decides what to do with them.
//
// TODO: to optimize memory access for different memory types, write specialized versions.
//
// NOTE(review): the functors in this file only read the source element. It is still
// passed to f as a non-const lvalue, so switching p_src to `const Float*` would change
// the argument type f receives. Left as-is for caller compatibility.
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class MapDst2Src, class F>
__device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
    Float* const __restrict__ p_src,
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    MapDst2Src,
    F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);

    // IR0/IR1 index the 2-element array `did` below. A map that is not a permutation
    // of {0, 1} has two failure modes: an out-of-range entry reads past `did`, and a
    // repeated entry sends several source elements to the same destination element.
    static_assert((IR0 == 0 && IR1 == 1) || (IR0 == 1 && IR1 == 0),
                  "wrong! MapDst2Src must be a permutation of {0, 1}");

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            const index_t aindex = src_desc.Get1dIndex(did0, did1);

            // Destination coordinate k is read from source coordinate MapDst2Src[k].
            const index_t did[2] = {did0, did1};

            const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);

            f(p_src[aindex], p_dst[bindex]);
        }
    }
}

// Overwrite every element of the 2D tensor described by Desc with Float(0).
// It reuses the generic in-place unary walker with a zero-assigning functor.
template <class Float, class Desc>
__device__ void threadwise_2d_tensor_set_zero(Desc, Float* __restrict__ p)
{
    auto zero_out = [](Float& elem) { elem = Float(0); };

    threadwise_2d_tensor_pointwise_operation_unary<Float, Desc, decltype(zero_out)>(
        Desc{}, p, zero_out);
}

// Copy dst = src element by element over the SrcOpLengths region.
// Destination dimension k takes its coordinate from source dimension MapDst2Src[k];
// the binary reorder walker applies that permutation.
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class MapDst2Src>
__device__ void
threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                      Float* const __restrict__ p_src,
                                                      DstDesc,
                                                      Float* __restrict__ p_dst,
                                                      SrcOpLengths,
                                                      MapDst2Src)
{
    threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
        SrcDesc{},
        p_src,
        DstDesc{},
        p_dst,
        SrcOpLengths{},
        MapDst2Src{},
        [](const Float& from, Float& to) { to = from; });
}

// Plain 2D copy over the SrcOpLengths region, with no dimension reordering.
// The identity map Sequence<0, 1> keeps every destination dimension aligned with the
// same source dimension.
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void threadwise_2d_tensor_copy(
    SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
{
    using IdentityMap = Sequence<0, 1>;

    threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, IdentityMap{});
}

// Shift the 2D tensor described by Desc "down" by NShift positions along dimension
// IDim, in place: p(i) = p(i + nshift * e_IDim) for every i whose source stays inside
// the tensor. The last nshift slices along IDim are left unchanged.
// The in-place update is safe because the source of each write comes later in the
// loop order than its destination, so it has not been overwritten yet.
//
// IDim   : Number<0> or Number<1>, the dimension to shift along.
// NShift : Number<n>, the shift amount; must not exceed the length along IDim.
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr auto desc = Desc{};

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_2d_tensor_shift_down: ");
    }
#endif

    constexpr index_t nshift = NShift::mValue;

    // decltype(I0) is `const Number<0>` because a constexpr variable is implicitly const.
    // IDim is deduced from a by-value argument and so carries no cv-qualifier. An exact
    // (cv-sensitive) comparison against plain IDim would therefore never match, leaving
    // the loop bound unshortened and making p[dindex + offset] read past the tensor.
    // Comparing against `const IDim` matches both deduced and explicitly-const IDim.
    constexpr bool shift_dim0 = is_same<decltype(I0), const IDim>::value;
    constexpr bool shift_dim1 = is_same<decltype(I1), const IDim>::value;

    static_assert(shift_dim0 || shift_dim1, "wrong! IDim must be Number<0> or Number<1>");
    static_assert(nshift <= desc.GetLength(IDim{}),
                  "wrong! shift amount exceeds tensor length along IDim");

    // Only positions whose source (nshift further along IDim) is in range get written.
    constexpr index_t did0_end = shift_dim0 ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
    constexpr index_t did1_end = shift_dim1 ? desc.GetLength(I1) - nshift : desc.GetLength(I1);

    // Linear distance from a destination element to its source (loop invariant).
    const index_t src_offset = nshift * desc.GetStride(IDim{});

    for(index_t did0 = 0; did0 < did0_end; ++did0)
    {
        for(index_t did1 = 0; did1 < did1_end; ++did1)
        {
            const index_t dindex = desc.Get1dIndex(did0, did1);

            p[dindex] = p[dindex + src_offset];
        }
    }
}