"examples/FP16_Optimizer_simple/distributed_apex/run.sh" did not exist on "83acda92a800b7378d2e1b19c82d2ba9cae62d86"
space_filling_curve.cpp 5.98 KB
Newer Older
Jianfeng Yan's avatar
Jianfeng Yan committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>

#include "tensor_space_filling_curve.hpp"

using namespace ck;

void traverse_using_space_filling_curve();

// Test entry point.
//
// Calls traverse_using_space_filling_curve(), whose checks are all
// static_asserts, then synchronizes with the HIP device and reports the result
// through the process exit code.
//
// Returns 0 when hipDeviceSynchronize() reports hipSuccess, 1 otherwise.
int main(int argc, char** argv)
{
    (void)argc;
    (void)argv;

    traverse_using_space_filling_curve();

    // Check the status explicitly instead of through assert(): assert() is
    // compiled out when NDEBUG is defined, which would let a failed
    // synchronize go unnoticed and the test exit with 0.
    const auto err = hipDeviceSynchronize();
    if(err != hipSuccess)
    {
        std::cerr << "hipDeviceSynchronize failed with error code " << static_cast<int>(err)
                  << '\n';
        return 1;
    }

    return 0;
}

void traverse_using_space_filling_curve()
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    using TensorLengths     = Sequence<4, 10, 9>;
    using DimAccessOrder    = Sequence<2, 0, 1>;
    using ScalarsPerAccess  = Sequence<1, 2, 3>;
    using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;

    constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
                                         make_tuple(0, 2, 0),
                                         make_tuple(0, 4, 0),
                                         make_tuple(0, 6, 0),
                                         make_tuple(0, 8, 0),
                                         make_tuple(1, 8, 0),
                                         make_tuple(1, 6, 0),
                                         make_tuple(1, 4, 0),
                                         make_tuple(1, 2, 0),
                                         make_tuple(1, 0, 0),
                                         make_tuple(2, 0, 0),
                                         make_tuple(2, 2, 0),
                                         make_tuple(2, 4, 0),
                                         make_tuple(2, 6, 0),
                                         make_tuple(2, 8, 0),
                                         make_tuple(3, 8, 0),
                                         make_tuple(3, 6, 0),
                                         make_tuple(3, 4, 0),
                                         make_tuple(3, 2, 0),
                                         make_tuple(3, 0, 0),
                                         make_tuple(3, 0, 3),
                                         make_tuple(3, 2, 3),
                                         make_tuple(3, 4, 3),
                                         make_tuple(3, 6, 3),
                                         make_tuple(3, 8, 3),
                                         make_tuple(2, 8, 3),
                                         make_tuple(2, 6, 3),
                                         make_tuple(2, 4, 3),
                                         make_tuple(2, 2, 3),
                                         make_tuple(2, 0, 3),
                                         make_tuple(1, 0, 3),
                                         make_tuple(1, 2, 3),
                                         make_tuple(1, 4, 3),
                                         make_tuple(1, 6, 3),
                                         make_tuple(1, 8, 3),
                                         make_tuple(0, 8, 3),
                                         make_tuple(0, 6, 3),
                                         make_tuple(0, 4, 3),
                                         make_tuple(0, 2, 3),
                                         make_tuple(0, 0, 3),
                                         make_tuple(0, 0, 6),
                                         make_tuple(0, 2, 6),
                                         make_tuple(0, 4, 6),
                                         make_tuple(0, 6, 6),
                                         make_tuple(0, 8, 6),
                                         make_tuple(1, 8, 6),
                                         make_tuple(1, 6, 6),
                                         make_tuple(1, 4, 6),
                                         make_tuple(1, 2, 6),
                                         make_tuple(1, 0, 6),
                                         make_tuple(2, 0, 6),
                                         make_tuple(2, 2, 6),
                                         make_tuple(2, 4, 6),
                                         make_tuple(2, 6, 6),
                                         make_tuple(2, 8, 6),
                                         make_tuple(3, 8, 6),
                                         make_tuple(3, 6, 6),
                                         make_tuple(3, 4, 6),
                                         make_tuple(3, 2, 6),
                                         make_tuple(3, 0, 6));

    constexpr index_t num_accesses = SpaceFillingCurve::GetNumOfAccess();

    static_assert(num_accesses == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{},
                                                     math::multiplies{},
                                                     Number<1>{}));

    static_for<1, num_accesses, 1>{}([&](auto i) {
        constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);

        static_assert(idx_curr[I0] == expected[i][I0]);
        static_assert(idx_curr[I1] == expected[i][I1]);
        static_assert(idx_curr[I2] == expected[i][I2]);

        constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i);
        constexpr auto expected_step = expected[i - I1] - expected[i];
        static_assert(backward_step[I0] == expected_step[I0]);
        static_assert(backward_step[I1] == expected_step[I1]);
        static_assert(backward_step[I2] == expected_step[I2]);
    });

    static_for<0, num_accesses - 1, 1>{}([&](auto i) {
        constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);

        static_assert(idx_curr[I0] == expected[i][I0]);
        static_assert(idx_curr[I1] == expected[i][I1]);
        static_assert(idx_curr[I2] == expected[i][I2]);

        constexpr auto forward_step  = SpaceFillingCurve::GetForwardStep(i);
        constexpr auto expected_step = expected[i + I1] - expected[i];
        static_assert(forward_step[I0] == expected_step[I0]);
        static_assert(forward_step[I1] == expected_step[I1]);
        static_assert(forward_step[I2] == expected_step[I2]);
    });
}