// threadwise_nd_tensor_op.hip.hpp
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

// need to assume src and dst is aligned
Chao Liu's avatar
Chao Liu committed
5
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
6
__device__ void threadwise_nd_tensor_copy(SrcDesc,
7
8
9
10
11
12
                                          const Float* __restrict__ p_src,
                                          DstDesc,
                                          Float* __restrict__ p_dst,
                                          SrcOpLengths,
                                          Number<DataPerRead>)
{
Chao Liu's avatar
Chao Liu committed
13
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
14

15
    constexpr index_t nDim = SrcOpLengths::GetSize();
16

17
18
    static_assert(SrcDesc{}.GetDimension() == nDim && DstDesc{}.GetDimension() == nDim,
                  "wrong! dimension not consistent");
19
20
21
22
23

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

24
25
#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
26
    {
27
28
29
        print_ConstantTensorDescriptor(src_desc, "src_desc");
        print_ConstantTensorDescriptor(dst_desc, "dst_desc");
        print_ConstantTensorDescriptor(ref_desc, "ref_desc");
30
    }
31
#endif
32

33
34
35
    static_assert(DataPerRead == 1 || (SrcDesc{}.GetStride(Number<nDim - 1>{}) == 1 &&
                                       DstDesc{}.GetStride(Number<nDim - 1>{}) == 1),
                  "wrong! only support stride[nDim-1] == 1!\n");
36
37
38
39

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

40
41
42
43
    static_assert(
        SrcDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0 &&
            DstDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0,
        "wrong! src and dst stride[nDim-2] should be multiple of DataPerRead to keep alignment");
44

45
    constexpr index_t L_Back = SrcOpLengths{}.Back();
46

47
48
    static_assert(L_Back % DataPerRead == 0,
                  "wrong! lengths[nDim-1] should be evenly divided by DataPerRead");
49

50
    constexpr index_t nRead = L_Back / DataPerRead;
51

52
53
54
    static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
        static_for<0, nRead, 1>{}([=](auto IRead) {
            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});
55

56
            const index_t src_index = src_desc.Get1dIndex(multi_id);
57

58
            const index_t dst_index = dst_desc.Get1dIndex(multi_id);
59

60
61
62
63
            *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
                *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
        });
    });
64
}