"vscode:/vscode.git/clone" did not exist on "1c54a541e93f44289a420e7584cc00102a1d00c7"
slice_tile.hpp 2.9 KB
Newer Older
carlushuang's avatar
carlushuang committed
1
2
3
4
5
6
7
8
9
10
11
12
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

#include "ck/tile_program/tile/tile_distribution.hpp"
Chao Liu's avatar
Chao Liu committed
13
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
carlushuang's avatar
carlushuang committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include "ck/tile_program/tile/tile_window.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"

namespace ck {
namespace tile_program {

// get_slice_tile: build a new distributed tensor holding the sub-tile of `tile`
// selected by [slice_begins, slice_ends).
//
// slice_begins / slice_ends are forwarded unchanged to
// detail::slice_distribution_from_x. Judging by the helper's name they are
// presumably per-X-dimension bounds, and whether the end is exclusive is defined
// by that helper. NOTE(review): confirm against the helper.
//
// That helper yields a 3-element tuple:
//   At<0>  the distribution describing the slice,
//   At<1>  Y-space origins of the slice,
//   At<2>  Y-space lengths of the slice.
//
// The result has the same DataType as `tile`. Its thread buffer is filled from
// tile.GetYSlicedThreadData(origins, lengths). No cross-thread communication is
// visible here; each thread only re-packs data it already owns in its buffer.
template <typename StaticDistributedTensor_, index_t... SliceBegins, index_t... SliceEnds>
__host__ __device__ constexpr auto get_slice_tile(const StaticDistributedTensor_& tile,
                                                  Sequence<SliceBegins...> slice_begins,
                                                  Sequence<SliceEnds...> slice_ends)
{
    using ElemType = typename StaticDistributedTensor_::DataType;
    using FullDstr = decltype(StaticDistributedTensor_::GetTileDistribution());

    // (sliced distribution, Y origins, Y lengths), all known at compile time
    constexpr auto slice_info =
        detail::slice_distribution_from_x(FullDstr{}, slice_begins, slice_ends);

    constexpr auto part_dstr    = slice_info.template At<0>();
    constexpr auto part_y_begin = slice_info.template At<1>();
    constexpr auto part_y_len   = slice_info.template At<2>();

    // Allocate the output with the sliced layout, then copy the matching Y-range
    // of this thread's data into it.
    auto part = make_static_distributed_tensor<ElemType>(part_dstr);
    part.GetThreadBuffer() = tile.GetYSlicedThreadData(part_y_begin, part_y_len);

    return part;
}

// set_slice_tile: write the per-thread data of `src_tile` into the sub-tile of
// `dst_tile` selected by [slice_begins, slice_ends). This is the inverse of
// get_slice_tile.
//
// slice_begins / slice_ends are forwarded unchanged to
// detail::slice_distribution_from_x, applied to dst_tile's distribution.
// Presumably they are per-X-dimension bounds. NOTE(review): confirm the end
// convention against the helper.
//
// src_tile's thread buffer is stored verbatim into the Y-range
// [At<1>, At<1> + At<2>) of dst_tile's thread buffer. Therefore src_tile must be
// distributed exactly like the slice; this is enforced at compile time below.
// A tile obtained from get_slice_tile on a tile with dst_tile's distribution
// satisfies this.
template <typename DstStaticDistributedTensor_,
          typename SrcStaticDistributedTensor_,
          index_t... SliceBegins,
          index_t... SliceEnds>
__host__ __device__ constexpr auto set_slice_tile(DstStaticDistributedTensor_& dst_tile,
                                                  const SrcStaticDistributedTensor_& src_tile,
                                                  Sequence<SliceBegins...> slice_begins,
                                                  Sequence<SliceEnds...> slice_ends)
{
    using DstDistribution = decltype(DstStaticDistributedTensor_::GetTileDistribution());
    using SrcDistribution = decltype(SrcStaticDistributedTensor_::GetTileDistribution());

    // (sliced distribution, Y origins, Y lengths) of the slice within dst_tile
    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template At<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>();

    // The source buffer is copied as-is, so the source must have the *sliced*
    // distribution. The previous check compared against the full dst
    // distribution instead.
    // decltype of a constexpr variable is const-qualified, so strip cv/ref before
    // comparing types; otherwise the check can never succeed.
    static_assert(
        is_same_v<remove_cvref_t<decltype(sliced_dstr)>, remove_cvref_t<SrcDistribution>>,
        "wrong! src_tile distribution must equal the sliced dst_tile distribution");

    // Counterpart of GetYSlicedThreadData used by get_slice_tile.
    dst_tile.SetYSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer());
}

} // namespace tile_program
} // namespace ck