// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>

#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

using namespace ck;

void traverse_using_space_filling_curve_linear();
void traverse_using_space_filling_curve_snakecurved();

int main(int argc, char** argv)
{
    (void)argc;
    (void)argv;

23
24
    traverse_using_space_filling_curve_linear();
    traverse_using_space_filling_curve_snakecurved();
25

Jianfeng Yan's avatar
Jianfeng Yan committed
26
27
28
    return 0;
}

void traverse_using_space_filling_curve_linear()
Jianfeng Yan's avatar
Jianfeng Yan committed
30
31
32
33
34
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    using TensorLengths    = Sequence<3, 2, 2>;
    using DimAccessOrder   = Sequence<2, 0, 1>;
    using ScalarsPerAccess = Sequence<1, 1, 1>;
    using SpaceFillingCurve =
        SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, false>;

    constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
                                         make_tuple(0, 1, 0),
                                         make_tuple(1, 0, 0),
                                         make_tuple(1, 1, 0),
                                         make_tuple(2, 0, 0),
                                         make_tuple(2, 1, 0),
                                         make_tuple(0, 0, 1),
                                         make_tuple(0, 1, 1),
                                         make_tuple(1, 0, 1),
                                         make_tuple(1, 1, 1),
                                         make_tuple(2, 0, 1),
                                         make_tuple(2, 1, 1));

    constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();

    static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{},
                                                   math::multiplies{},
                                                   Number<1>{}));

    static_for<1, num_access, 1>{}([&](auto i) {
        constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);

        static_assert(idx_curr[I0] == expected[i][I0]);
        static_assert(idx_curr[I1] == expected[i][I1]);
        static_assert(idx_curr[I2] == expected[i][I2]);

        constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i);
        constexpr auto expected_step = expected[i - I1] - expected[i];
        static_assert(backward_step[I0] == expected_step[I0]);
        static_assert(backward_step[I1] == expected_step[I1]);
        static_assert(backward_step[I2] == expected_step[I2]);
    });

    static_for<0, num_access - 1, 1>{}([&](auto i) {
        constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);

        static_assert(idx_curr[I0] == expected[i][I0]);
        static_assert(idx_curr[I1] == expected[i][I1]);
        static_assert(idx_curr[I2] == expected[i][I2]);

        constexpr auto forward_step  = SpaceFillingCurve::GetForwardStep(i);
        constexpr auto expected_step = expected[i + I1] - expected[i];
        static_assert(forward_step[I0] == expected_step[I0]);
        static_assert(forward_step[I1] == expected_step[I1]);
        static_assert(forward_step[I2] == expected_step[I2]);
    });
}

void traverse_using_space_filling_curve_snakecurved()
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    using TensorLengths    = Sequence<16, 10, 9>;
    using DimAccessOrder   = Sequence<2, 0, 1>;
    using ScalarsPerAccess = Sequence<4, 2, 3>;
    using SpaceFillingCurve =
        SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, true>;
Jianfeng Yan's avatar
Jianfeng Yan committed
100
101
102
103
104
105

    constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
                                         make_tuple(0, 2, 0),
                                         make_tuple(0, 4, 0),
                                         make_tuple(0, 6, 0),
                                         make_tuple(0, 8, 0),
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
                                         make_tuple(4, 8, 0),
                                         make_tuple(4, 6, 0),
                                         make_tuple(4, 4, 0),
                                         make_tuple(4, 2, 0),
                                         make_tuple(4, 0, 0),
                                         make_tuple(8, 0, 0),
                                         make_tuple(8, 2, 0),
                                         make_tuple(8, 4, 0),
                                         make_tuple(8, 6, 0),
                                         make_tuple(8, 8, 0),
                                         make_tuple(12, 8, 0),
                                         make_tuple(12, 6, 0),
                                         make_tuple(12, 4, 0),
                                         make_tuple(12, 2, 0),
                                         make_tuple(12, 0, 0),
                                         make_tuple(12, 0, 3),
                                         make_tuple(12, 2, 3),
                                         make_tuple(12, 4, 3),
                                         make_tuple(12, 6, 3),
                                         make_tuple(12, 8, 3),
                                         make_tuple(8, 8, 3),
                                         make_tuple(8, 6, 3),
                                         make_tuple(8, 4, 3),
                                         make_tuple(8, 2, 3),
                                         make_tuple(8, 0, 3),
                                         make_tuple(4, 0, 3),
                                         make_tuple(4, 2, 3),
                                         make_tuple(4, 4, 3),
                                         make_tuple(4, 6, 3),
                                         make_tuple(4, 8, 3),
Jianfeng Yan's avatar
Jianfeng Yan committed
136
137
138
139
140
141
142
143
144
145
                                         make_tuple(0, 8, 3),
                                         make_tuple(0, 6, 3),
                                         make_tuple(0, 4, 3),
                                         make_tuple(0, 2, 3),
                                         make_tuple(0, 0, 3),
                                         make_tuple(0, 0, 6),
                                         make_tuple(0, 2, 6),
                                         make_tuple(0, 4, 6),
                                         make_tuple(0, 6, 6),
                                         make_tuple(0, 8, 6),
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
                                         make_tuple(4, 8, 6),
                                         make_tuple(4, 6, 6),
                                         make_tuple(4, 4, 6),
                                         make_tuple(4, 2, 6),
                                         make_tuple(4, 0, 6),
                                         make_tuple(8, 0, 6),
                                         make_tuple(8, 2, 6),
                                         make_tuple(8, 4, 6),
                                         make_tuple(8, 6, 6),
                                         make_tuple(8, 8, 6),
                                         make_tuple(12, 8, 6),
                                         make_tuple(12, 6, 6),
                                         make_tuple(12, 4, 6),
                                         make_tuple(12, 2, 6),
                                         make_tuple(12, 0, 6));
Jianfeng Yan's avatar
Jianfeng Yan committed
161

162
    constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
Jianfeng Yan's avatar
Jianfeng Yan committed
163

164
165
166
    static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{},
                                                   math::multiplies{},
                                                   Number<1>{}));
Jianfeng Yan's avatar
Jianfeng Yan committed
167

168
    static_for<1, num_access, 1>{}([&](auto i) {
Jianfeng Yan's avatar
Jianfeng Yan committed
169
170
171
172
173
174
175
176
177
178
179
180
181
        constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);

        static_assert(idx_curr[I0] == expected[i][I0]);
        static_assert(idx_curr[I1] == expected[i][I1]);
        static_assert(idx_curr[I2] == expected[i][I2]);

        constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i);
        constexpr auto expected_step = expected[i - I1] - expected[i];
        static_assert(backward_step[I0] == expected_step[I0]);
        static_assert(backward_step[I1] == expected_step[I1]);
        static_assert(backward_step[I2] == expected_step[I2]);
    });

182
    static_for<0, num_access - 1, 1>{}([&](auto i) {
Jianfeng Yan's avatar
Jianfeng Yan committed
183
184
185
186
187
188
189
190
191
192
193
194
195
        constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);

        static_assert(idx_curr[I0] == expected[i][I0]);
        static_assert(idx_curr[I1] == expected[i][I1]);
        static_assert(idx_curr[I2] == expected[i][I2]);

        constexpr auto forward_step  = SpaceFillingCurve::GetForwardStep(i);
        constexpr auto expected_step = expected[i + I1] - expected[i];
        static_assert(forward_step[I0] == expected_step[I0]);
        static_assert(forward_step[I1] == expected_step[I1]);
        static_assert(forward_step[I2] == expected_step[I2]);
    });
}