// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/utils/tensor_utils.hpp"

#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

/**
 * \brief Perform an optimized copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size. A usage sketch follows the definition below.
 *
 * \tparam DimAccessOrderTuple Tuple with the dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalars per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();

    constexpr auto thread_slice_lengths =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto dim_access_order = generate_sequence_v2(
        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform a copy between DynamicBuffers
        auto transfer = ThreadwiseTensorSliceTransfer_v7<
            Tuple<typename SrcTensorType::TensorElementType>,
            Tuple<typename DstTensorType::TensorElementType>,
            decltype(tie(in_grid_desc)),
            decltype(tie(out_grid_desc)),
            tensor_operation::element_wise::PassThrough,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            Sequence<true>,
            Sequence<true>>{in_grid_desc,
                            make_tuple(src_tensor.GetMultiIdxOffsets()),
                            out_grid_desc,
                            make_tuple(dst_tensor.GetMultiIdxOffsets()),
                            tensor_operation::element_wise::PassThrough{}};

        transfer.Run(tie(in_grid_desc),
                     tie(src_tensor.GetBuffer()),
                     tie(out_grid_desc),
                     tie(dst_tensor.GetBuffer()));
    }
    else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from StaticBuffer to DynamicBuffer
        const auto src_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        auto transfer =
            ThreadwiseTensorSliceTransfer_v1r3<typename SrcTensorType::TensorElementType,
                                               typename DstTensorType::TensorElementType,
                                               remove_cvref_t<decltype(in_grid_desc)>,
                                               remove_cvref_t<decltype(out_grid_desc)>,
                                               tensor_operation::element_wise::PassThrough,
                                               decltype(thread_slice_lengths),
                                               decltype(dim_access_order),
                                               VectorDim,
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               I1,
                                               true>{out_grid_desc,
                                                     dst_tensor.GetMultiIdxOffsets(),
                                                     tensor_operation::element_wise::PassThrough{}};

        transfer.Run(in_grid_desc,
                     src_slice_origin_idxs,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     dst_tensor.GetBuffer());
    }
    else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from DynamicBuffer to StaticBuffer
        const auto dst_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
        auto transfer = ThreadwiseTensorSliceTransfer_v2<
            std::remove_const_t<typename SrcTensorType::TensorElementType>,
            std::remove_const_t<typename DstTensorType::TensorElementType>,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            I1,
            false,
            false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};

        transfer.Run(in_grid_desc,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     dst_slice_origin_idxs,
                     dst_tensor.GetBuffer());
    }
    else
    {
        // Perform copy between StaticBuffers
        static_for<0, size(SrcShapeType{}), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
}
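
// Usage sketch (illustrative only, not part of the library API). It assumes
// `src_partition` and `dst_partition` are equally sized ck::wrapper tensor
// partitions created elsewhere (e.g. with the wrapper layout/tensor helpers);
// the rank, access order and vector width below are example values:
//
//     // Traverse dim 0 then dim 1 and vectorize dim 1 with 4 scalars per access.
//     using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
//     ck::wrapper::copy<DimAccessOrder, 1 /*VectorDim*/, 4 /*ScalarPerVector*/>(
//         src_partition, dst_partition);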

/**
 * \brief Perform a generic copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size. A usage sketch follows the definition below.
 *
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    // Generate default params
    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();
    // Incrementing dims 0, 1, 2 ... num_dims - 1
    constexpr auto dim_access_order_tuple =
        generate_tuple([](auto i) { return Number<i>{}; }, Number<num_dims>{});
    constexpr index_t vector_dim        = num_dims - 1;
    constexpr index_t scalar_per_vector = 1;
    copy<decltype(dim_access_order_tuple), vector_dim, scalar_per_vector>(src_tensor, dst_tensor);
}
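
// Usage sketch (illustrative only): the generic overload needs no template
// arguments; it traverses the dimensions in order 0..N-1 and copies one scalar
// per access, so it works for any pair of equally sized partitions:
//
//     ck::wrapper::copy(src_partition, dst_partition);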

/**
 * \brief Perform an optimized blockwise copy between two tensors. Tensors must have
 * the same size. A usage sketch follows the definition below.
 *
 * \note Vgpr and Sgpr tensors are currently not supported.
 *
 * \tparam DimAccessOrderTuple Tuple with the dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalars per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 * \param thread_layout Thread layout (threads per dimension) used for the copy.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType,
          typename ThreadShape,
          typename ThreadUnrolledDesc>
__device__ void
blockwise_copy(const SrcTensorType& src_tensor,
               DstTensorType& dst_tensor,
               [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
{
    static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();

    constexpr auto tile_lengths_seq =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto thread_layout_seq =
        generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
    constexpr auto dim_access_order = generate_sequence_v2(
        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;

    // Perform copy between DynamicBuffers
    auto transfer = ThreadGroupTensorSliceTransfer_v7<
        ThisThreadBlock,
        Tuple<typename SrcTensorType::TensorElementType>,
        Tuple<typename DstTensorType::TensorElementType>,
        decltype(tie(in_grid_desc)),
        decltype(tie(out_grid_desc)),
        tensor_operation::element_wise::PassThrough,
        Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
        std::remove_const_t<decltype(tile_lengths_seq)>,
        std::remove_const_t<decltype(thread_layout_seq)>,
        std::remove_const_t<decltype(dim_access_order)>,
        std::remove_const_t<decltype(dim_access_order)>,
        VectorDim,
        ScalarPerVector,
        Sequence<true>,
        Sequence<true>>{in_grid_desc,
                        make_tuple(src_tensor.GetMultiIdxOffsets()),
                        out_grid_desc,
                        make_tuple(dst_tensor.GetMultiIdxOffsets()),
                        tensor_operation::element_wise::PassThrough{}};

    transfer.Run(tie(in_grid_desc),
                 tie(src_tensor.GetBuffer()),
                 tie(out_grid_desc),
                 tie(dst_tensor.GetBuffer()));
}
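
// Usage sketch (illustrative only, not part of the library API). The thread
// layout factory name below (`make_layout`) and its shape-only overload are
// assumed; consult the wrapper layout utilities for the exact signature:
//
//     // Hypothetical 4x64 thread layout: 4 threads along dim 0, 64 along dim 1.
//     constexpr auto thread_layout =
//         ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}));
//     using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
//     ck::wrapper::blockwise_copy<DimAccessOrder, 1 /*VectorDim*/, 8 /*ScalarPerVector*/>(
//         src_tile, dst_tile, thread_layout);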

} // namespace wrapper
} // namespace ck