cluster_descriptor.hpp 1.12 KB
Newer Older
Yang0001's avatar
Yang0001 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

namespace ck {

// Builds a single-stage tensor adaptor that merges every dimension of `lengths`
// into one upper dimension (id 0).
//
// lengths : container of per-dimension lengths; its static size Lengths::Size()
//           fixes the number of lower dimensions.
// order   : permutation handed to container_reorder_given_new2old to rearrange
//           `lengths` before merging. The same ArrangeOrder type also supplies
//           the lower dimension ids of the adaptor. When omitted it defaults to
//           arithmetic_sequence_gen<0, Lengths::Size(), 1>::type (presumably the
//           identity order 0, 1, ..., Size-1 -- confirm against its definition).
//
// Returns the result of make_single_stage_tensor_adaptor for one merge transform.
template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
    const Lengths& lengths,
    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
    // Dimension ids on both sides of the merge: the lower ids come straight from
    // the arrange order, the upper side is the single merged dimension 0.
    constexpr auto merged_dim_ids = Sequence<0>{};
    constexpr auto source_dim_ids = ArrangeOrder{};

    constexpr index_t num_dims = Lengths::Size();

    // Rearrange the lengths according to `order` before they are merged.
    const auto permuted_lengths = container_reorder_given_new2old(lengths, order);

    // Repack the permuted lengths element by element into a tuple for the merge transform.
    const auto merge_lengths = generate_tuple(
        [&](auto idim) { return permuted_lengths[idim]; }, Number<num_dims>{});

    const auto merge_transform = make_merge_transform(merge_lengths);

    return make_single_stage_tensor_adaptor(make_tuple(merge_transform),
                                            make_tuple(source_dim_ids),
                                            make_tuple(merged_dim_ids));
}

} // namespace ck