sequence_helper.hpp 2.69 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
2
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
3

4
#pragma once
Chao Liu's avatar
Chao Liu committed
5

carlushuang's avatar
carlushuang committed
6
7
#include "ck/utility/sequence.hpp"
#include "ck/utility/array.hpp"
8
#include "ck/utility/tuple.hpp"
carlushuang's avatar
carlushuang committed
9
#include "ck/utility/macro_func_array_to_sequence.hpp"
Chao Liu's avatar
Chao Liu committed
10
11
12

namespace ck {

13
14
15
16
17
18
// Turns a pack of compile-time Number<> values into the matching Sequence<> type,
// e.g. make_sequence(Number<1>{}, Number<4>{}) yields Sequence<1, 4>{}.
// The arguments carry no runtime data; only their template values are used.
template <index_t... Is>
__host__ __device__ constexpr auto make_sequence(Number<Is>...)
{
    using seq_t = Sequence<Is...>;
    return seq_t{};
}

Chao Liu's avatar
Chao Liu committed
19
// F() returns index_t
carlushuang's avatar
carlushuang committed
20
// F is default-constructed, so F cannot be a lambda
Chao Liu's avatar
Chao Liu committed
21
22
23
24
25
26
// Builds a Sequence of length N from the type F via sequence_gen<N, F>.
// The F argument object itself is ignored here; only its type is forwarded,
// which is why F must be default-constructible (see the note above the function).
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
{
    using generated_t = typename sequence_gen<N, F>::type;
    return generated_t{};
}

Chao Liu's avatar
Chao Liu committed
27
// F() returns Number<>
carlushuang's avatar
carlushuang committed
28
// F may be a lambda
29
30
31
32
33
34
35
// Builds a Sequence of length N by invoking the given callable f on every element
// of arithmetic_sequence_gen<0, N, 1> (i.e. indices 0, 1, ..., N-1) and packing the
// Number<> results with make_sequence. Because f itself is invoked (not
// default-constructed), f may be a lambda.
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
{
    using index_seq_t = typename arithmetic_sequence_gen<0, N, 1>::type;

    auto apply_to_all = [&f](auto&&... is) { return make_sequence(f(is)...); };

    return unpack(apply_to_all, index_seq_t{});
}

36
37
38
39
40
41
// Converts a Tuple of compile-time Number<> values into the equivalent Sequence<> type,
// e.g. Tuple<Number<2>, Number<7>> -> Sequence<2, 7>{}.
template <index_t... Is>
__host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
{
    using seq_t = Sequence<Is...>;
    return seq_t{};
}

carlushuang's avatar
carlushuang committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
namespace detail {
// Compile-time worker that bins a sorted sample sequence into a histogram.
//
// Template parameters:
//   h_idx            - index (into h) of the bin that SeqRange's first element closes.
//   SeqSortedSamples - Sequence<> of samples; expected to be sorted ascending.
//   SeqRange         - Sequence<> of the remaining exclusive upper bounds of the bins.
//
// Each sample x is counted once, in the first remaining bin whose upper bound r
// satisfies x < r. Samples that are >= the last upper bound are not counted.
//
// Histogram must provide `template At<I>()` returning a mutable reference, and
// must be zero-initialized by the caller: counts are accumulated with +=.
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>
struct sorted_sequence_histogram;

// General step: at least one sample and at least one bin remain.
template <index_t h_idx, index_t x, index_t... xs, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, Sequence<x, xs...>, Sequence<r, rs...>>
{
    template <typename Histogram>
    constexpr void operator()(Histogram& h)
    {
        if constexpr(x < r)
        {
            // x belongs to the current bin: count it, consume it, stay on this bin.
            h.template At<h_idx>() += 1;
            sorted_sequence_histogram<h_idx, Sequence<xs...>, Sequence<r, rs...>>{}(h);
        }
        else
        {
            // x lies beyond the current bin: advance to the next bin WITHOUT
            // consuming x. This lets a sample skip several (empty) bins and still
            // land in the correct one, and keeps the final sample from being dropped.
            sorted_sequence_histogram<h_idx + 1, Sequence<x, xs...>, Sequence<rs...>>{}(h);
        }
    }
};

// All samples consumed: nothing left to count.
template <index_t h_idx, index_t... rs>
struct sorted_sequence_histogram<h_idx, Sequence<>, Sequence<rs...>>
{
    template <typename Histogram>
    constexpr void operator()(Histogram&)
    {
    }
};

// Bins exhausted while samples remain: those samples are >= the last upper bound
// and fall outside the histogram. Ignore them rather than indexing past the end of h.
template <index_t h_idx, index_t x, index_t... xs>
struct sorted_sequence_histogram<h_idx, Sequence<x, xs...>, Sequence<>>
{
    template <typename Histogram>
    constexpr void operator()(Histogram&)
    {
    }
};
} // namespace detail

// SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1>
// Computes, at compile time, a histogram of a sorted sample sequence and returns
// it as a Sequence<> with one count per bin.
//
// The range Sequence<r, rs...> lists bin boundaries. There is one bin per
// boundary after the first, so the result has sizeof...(rs) entries. Only the
// boundaries rs... are passed to the counting helper as upper bounds; the first
// boundary r is not compared against any sample.
//
// Requires at least two boundaries (i.e. at least one bin).
template <typename SeqSortedSamples, index_t r, index_t... rs>
constexpr auto histogram_sorted_sequence(SeqSortedSamples, Sequence<r, rs...>)
{
    // Without this check, a one-boundary range creates Array<index_t, 0> h{0},
    // which fails with a far less readable error.
    static_assert(sizeof...(rs) > 0,
                  "histogram_sorted_sequence: range must contain at least two boundaries");

    constexpr auto bins = sizeof...(rs); // one bin per boundary after the first

    constexpr auto histogram = [&]() {
        // First element set to 0 explicitly. Remaining elements are value-initialized
        // to 0, assuming Array is an aggregate -- TODO confirm. The helper
        // accumulates with +=, so a zeroed start is required.
        Array<index_t, bins> h{0};
        detail::sorted_sequence_histogram<0, SeqSortedSamples, Sequence<rs...>>{}(h);
        return h;
    }();

    // Lift the constexpr Array into a Sequence<> type.
    return TO_SEQUENCE(histogram, bins);
}

Chao Liu's avatar
Chao Liu committed
93
} // namespace ck