// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/utils/tensor_utils.hpp"

#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Perform an optimized copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \tparam DimAccessOrderTuple Tuple with dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalars per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();

    constexpr auto thread_slice_lengths =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto dim_access_order = generate_sequence_v2(
        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

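    // Dispatch on the buffer kinds of the two partitions. Dynamic buffers are runtime-addressed
    // memory (e.g. global or LDS), while static buffers are register-resident; each pairing is
    // lowered to a dedicated threadwise transfer below.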
    if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform a copy between DynamicBuffers
        auto transfer = ThreadwiseTensorSliceTransfer_v7<
            Tuple<typename SrcTensorType::TensorElementType>,
            Tuple<typename DstTensorType::TensorElementType>,
            decltype(tie(in_grid_desc)),
            decltype(tie(out_grid_desc)),
            tensor_operation::element_wise::PassThrough,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            Sequence<true>,
            Sequence<true>>{in_grid_desc,
                            make_tuple(src_tensor.GetMultiIdxOffsets()),
                            out_grid_desc,
                            make_tuple(dst_tensor.GetMultiIdxOffsets()),
                            tensor_operation::element_wise::PassThrough{}};

        transfer.Run(tie(in_grid_desc),
                     tie(src_tensor.GetBuffer()),
                     tie(out_grid_desc),
                     tie(dst_tensor.GetBuffer()));
    }
    else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from StaticBuffer to DynamicBuffer
        const auto src_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        auto transfer =
            ThreadwiseTensorSliceTransfer_v1r3<typename SrcTensorType::TensorElementType,
                                               typename DstTensorType::TensorElementType,
                                               remove_cvref_t<decltype(in_grid_desc)>,
                                               remove_cvref_t<decltype(out_grid_desc)>,
                                               tensor_operation::element_wise::PassThrough,
                                               decltype(thread_slice_lengths),
                                               decltype(dim_access_order),
                                               VectorDim,
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               I1,
                                               true>{out_grid_desc,
                                                     dst_tensor.GetMultiIdxOffsets(),
                                                     tensor_operation::element_wise::PassThrough{}};

        transfer.Run(in_grid_desc,
                     src_slice_origin_idxs,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     dst_tensor.GetBuffer());
    }
    else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from DynamicBuffer to StaticBuffer
        const auto dst_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
        auto transfer = ThreadwiseTensorSliceTransfer_v2<
            std::remove_const_t<typename SrcTensorType::TensorElementType>,
            std::remove_const_t<typename DstTensorType::TensorElementType>,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            I1,
            false,
            false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};

        transfer.Run(in_grid_desc,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     dst_slice_origin_idxs,
                     dst_tensor.GetBuffer());
    }
    else
    {
        // Perform copy between StaticBuffers
        static_for<0, size(SrcShapeType{}), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
}
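
// Illustrative usage sketch (not part of the original header). It assumes thread-level partitions
// `src_partition` and `dst_partition` already exist (e.g. produced by the wrapper's partitioning
// helpers) and that both describe the same 2D tile; those names and how they are obtained are
// assumptions, only the `copy` call itself comes from this file.
//
//   // Access dimensions in order (0, 1), vectorize along the innermost dimension (1) and
//   // move 4 scalars per vectorized read/write.
//   using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
//   ck::wrapper::copy<DimAccessOrder, /*VectorDim=*/1, /*ScalarPerVector=*/4>(src_partition,
//                                                                             dst_partition);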

/**
 * \brief Perform a generic copy between two tensor partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    // Generate default params
    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();
    // Incrementing dims 0, 1, 2 ... num_dims - 1
    constexpr auto dim_access_order_tuple =
        generate_tuple([](auto i) { return Number<i>{}; }, Number<num_dims>{});
    constexpr index_t vector_dim        = num_dims - 1;
    constexpr index_t scalar_per_vector = 1;
    copy<decltype(dim_access_order_tuple), vector_dim, scalar_per_vector>(src_tensor, dst_tensor);
}
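
// Illustrative usage sketch (not part of the original header): the generic overload picks its own
// defaults (identity dimension order, vectorization along the last dimension, one scalar per
// access), so it only needs two equally sized partitions; the variable names are assumptions.
//
//   ck::wrapper::copy(src_partition, dst_partition);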

/**
 * \brief Perform an optimized blockwise copy between two tensors. Tensors must have the
 * same size.
 *
 * \note Vgpr and Sgpr buffers are not currently supported.
 *
 * \tparam DimAccessOrderTuple Tuple with dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalars per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 * \param thread_layout Per-dimension thread layout used for the copy.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType,
          typename ThreadShape,
          typename ThreadUnrolledDesc>
__device__ void
blockwise_copy(const SrcTensorType& src_tensor,
               DstTensorType& dst_tensor,
               [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
{
    static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();

    constexpr auto tile_lengths_seq =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto thread_layout_seq =
        generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
    constexpr auto dim_access_order = generate_sequence_v2(
        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

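    // The whole thread block, sized by the thread layout, acts as the cooperating thread group.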
    using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;

    // Perform copy between DynamicBuffers
    auto transfer = ThreadGroupTensorSliceTransfer_v7<
        ThisThreadBlock,
        Tuple<typename SrcTensorType::TensorElementType>,
        Tuple<typename DstTensorType::TensorElementType>,
        decltype(tie(in_grid_desc)),
        decltype(tie(out_grid_desc)),
        tensor_operation::element_wise::PassThrough,
        Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
        std::remove_const_t<decltype(tile_lengths_seq)>,
        std::remove_const_t<decltype(thread_layout_seq)>,
        std::remove_const_t<decltype(dim_access_order)>,
        std::remove_const_t<decltype(dim_access_order)>,
        VectorDim,
        ScalarPerVector,
        Sequence<true>,
        Sequence<true>>{in_grid_desc,
                        make_tuple(src_tensor.GetMultiIdxOffsets()),
                        out_grid_desc,
                        make_tuple(dst_tensor.GetMultiIdxOffsets()),
                        tensor_operation::element_wise::PassThrough{}};

    transfer.Run(tie(in_grid_desc),
                 tie(src_tensor.GetBuffer()),
                 tie(out_grid_desc),
                 tie(dst_tensor.GetBuffer()));
}
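
// Illustrative usage sketch (not part of the original header). It assumes a block of 8x8 = 64
// threads cooperating on two dynamic-buffer tiles `src_tile` and `dst_tile` (e.g. global and LDS);
// the thread-layout construction via `make_layout` and the tile names are assumptions, only the
// `blockwise_copy` call itself comes from this file.
//
//   constexpr auto thread_layout =
//       ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<8>{}),
//                                ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}));
//   using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
//   ck::wrapper::blockwise_copy<DimAccessOrder, /*VectorDim=*/1, /*ScalarPerVector=*/4>(
//       src_tile, dst_tile, thread_layout);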

} // namespace wrapper
} // namespace ck