block_swizzle_test.cpp 14.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
#include <stdio.h>
#include <string>
#include <algorithm>
#include <vector>
#include <limits>
#include "simple_args.h"

// Declares this tool's command-line options with their default values and help text, then
// parses argc/argv into them and returns the resulting argument set.
//   m, n, k                               : GEMM problem size (default 1024 each)
//   m_per_block, n_per_block, k_per_block : tile sizes (default 128, 128, 32)
//   num_cu, occupancy                     : passed to block_dispatcher_t (default 104, 2)
// NOTE(review): the chained calls assume simple_args_t::insert returns a reference to the same
// object, so every option is registered on `args` before parse() runs — confirm in
// simple_args.h.
simple_args_t create_arg(int argc, char** argv)
{
    simple_args_t args;
    args.insert("m", "1024", "matrix m")
        .insert("n", "1024", "matrix n")
        .insert("k", "1024", "matrix k")
        .insert("m_per_block", "128", "m_per_block")
        .insert("n_per_block", "128", "n_per_block")
        .insert("k_per_block", "32", "k_per_block")
        .insert("num_cu", "104", "num cu")
        .insert("occupancy", "2", "occupancy")
        .parse(argc, argv);
    return args;
}

namespace impl {
// Returns ceil(n / d) for an integral n and a non-zero d.
// Computed from the quotient and remainder rather than the common (n + d - 1) / d, which
// overflows when n is close to the maximum of T (e.g. n = UINT32_MAX gave 0) and mis-rounds
// negative operands.
template <typename T>
T integer_divide_ceil(T n, T d)
{
    const T quotient  = n / d;
    const T remainder = n % d;
    // Built-in division truncates towards zero, which already is the ceiling unless there is a
    // remainder and the exact result is positive (remainder and divisor share the same sign).
    const bool round_up = (remainder != T(0)) && ((remainder > T(0)) == (d > T(0)));
    return round_up ? static_cast<T>(quotient + T(1)) : quotient;
}

// Returns the smaller of a and b (a when they compare equal).
template <typename T>
T min(T a, T b)
{
    return a > b ? b : a;
}

// Returns the larger of a and b (b when they compare equal).
template <typename T>
T max(T a, T b)
{
    return a > b ? a : b;
}

} // namespace impl

// Stream-K style partitioner for a tiled GEMM of size m x n x k.
//
// The output is cut into ceil(m / m_per_block) * ceil(n / n_per_block) tiles, each needing
// k_iters_per_tile = ceil(k / k_per_block) k-iterations. The resulting linear iteration space
// (tile after tile, k_iters_per_tile iterations per tile) is handed out to two kinds of blocks:
//   * stream-k (sk) blocks, block_idx in [0, sk_num_blocks): together they own the first
//     sk_total_iters iterations; the first sk_num_big_blocks own k_iters_per_big_block
//     iterations each, the remaining sk blocks own one iteration less. An sk block's range
//     does not have to start or end on a tile boundary.
//   * data-parallel (dp) blocks, block_idx in [dp_start_block_idx, get_grid_dims_x()): each
//     owns dp_iters_per_block (== k_iters_per_tile) iterations, i.e. one whole tile, placed
//     after the sk range.
// Indices in [sk_num_blocks, dp_start_block_idx) are padding blocks that own no iterations;
// dp_start_block_idx is sk_num_blocks rounded up to a multiple of num_cu.
//
// Preconditions (not checked): m_per_block, n_per_block, k_per_block, num_cu and occupancy are
// all non-zero (they are used as divisors).
struct block_dispatcher_t
{
    public:
    // ---- inputs ----
    uint32_t m_per_block;
    uint32_t n_per_block;
    uint32_t k_per_block;
    uint32_t num_cu;
    uint32_t occupancy;
    uint32_t m;
    uint32_t n;
    uint32_t k;

    //--------------------------------------
    // ---- partition computed by init() ----

    uint32_t sk_num_blocks;     // number of stream-k blocks (0 => everything is data-parallel)
    uint32_t sk_num_big_blocks; // leading sk blocks that receive one extra iteration
    uint32_t sk_total_iters;    // iterations shared by all sk blocks

    uint32_t dp_start_block_idx; // index of the first dp block (a multiple of num_cu)
    uint32_t dp_iters_per_block; // iterations per dp block (one whole tile)
    uint32_t dp_num_blocks;      // number of dp blocks (one tile each)

    uint32_t k_iters_per_tile;      // ceil(k / k_per_block)
    uint32_t k_iters_per_big_block; // iterations of a big sk block; small ones get one less
    //--------------------------------------

    // Divisor applied to sk_total_iters when bounding the number of sk blocks to try.
    static constexpr uint32_t min_k_iters_per_sk_block = 1;

    // Prints the inputs and the computed partition on one line.
    void dump() const
    {
        printf("%ux%ux%u(%ux%ux%u), cu:%u, occ:%u, grids:%u, sk_num_big_blocks:%u, "
               "sk_num_blocks:%u, sk_total_iters:%u, dp_start_block_idx:%u, dp_iters_per_block:%u, "
               "dp_num_blocks:%u, k_iters_per_tile:%u, k_iters_per_big_block:%u\n",
               m,
               n,
               k,
               m_per_block,
               n_per_block,
               k_per_block,
               num_cu,
               occupancy,
               get_grid_dims_x(),
               sk_num_big_blocks,
               sk_num_blocks,
               sk_total_iters,
               dp_start_block_idx,
               dp_iters_per_block,
               dp_num_blocks,
               k_iters_per_tile,
               k_iters_per_big_block);
    }

    // Stores the tile sizes, the num_cu/occupancy parameters and the problem size, then
    // computes the partition (see init()).
    block_dispatcher_t(uint32_t m_per_block_,
                       uint32_t n_per_block_,
                       uint32_t k_per_block_,
                       uint32_t num_cu_,
                       uint32_t occupancy_,
                       uint32_t m_,
                       uint32_t n_,
                       uint32_t k_)
        : m_per_block(m_per_block_),
          n_per_block(n_per_block_),
          k_per_block(k_per_block_),
          num_cu(num_cu_),
          occupancy(occupancy_),
          m(m_),
          n(n_),
          k(k_)
    {
        init();
    }

    // Total number of blocks to launch: sk blocks, padding, then dp blocks.
    uint32_t get_grid_dims_x() const { return dp_start_block_idx + dp_num_blocks; }

    // Maps a launched block id to the logical block index used by get_block_itr().
    // Block ids are already laid out linearly (sk blocks, padding, dp blocks), so this is
    // currently the identity; it is the hook where a different block ordering would go.
    uint32_t get_block_idx(uint32_t bid) const { return bid; }

    // Returns the first iteration owned by block_idx (0 for padding blocks, which own nothing).
    uint32_t get_current_itr(uint32_t block_idx) const
    {
        uint32_t iter_start = 0;
        uint32_t iter_end   = 0;
        get_block_itr(block_idx, iter_start, iter_end);
        return iter_start;
    }

    // Writes the half-open iteration range [iter_start, iter_end) owned by block_idx.
    // Padding blocks (sk_num_blocks <= block_idx < dp_start_block_idx) get the empty range
    // [0, 0) so callers never read unassigned outputs.
    void get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            // small sk blocks start after all big ones and own one iteration less each
            const uint32_t k_iters_per_small_block = k_iters_per_big_block - 1;
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * k_iters_per_small_block;
            iter_end = iter_start + k_iters_per_small_block;
        }
        else if(block_idx >= dp_start_block_idx)
        {
            // dp blocks follow the whole sk range, one tile each
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
        else
        {
            // padding block between the sk and dp ranges: no work
            iter_start = 0;
            iter_end   = 0;
        }
    }

    private:
    // Computes every field of the partition from the inputs.
    void init()
    {
        uint32_t num_tiles =
            impl::integer_divide_ceil(m, m_per_block) * impl::integer_divide_ceil(n, n_per_block);
        k_iters_per_tile = impl::integer_divide_ceil(k, k_per_block);

        // one cu can hold one wg at one time, from the whole chip's point of view
        // if number of wg is same as num_cu, we call it 1 dispatch
        // if number of wg is 2x num_cu, we call it 2 dispatches.
        // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu
        // (partial dispatch)
        uint32_t full_dispatches         = num_tiles / num_cu;
        uint32_t full_dispatch_tiles     = full_dispatches * num_cu;
        uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;

        uint32_t sk_occupancy = occupancy;
        uint32_t dp_tiles     = full_dispatch_tiles;
        uint32_t sk_tiles     = partial_dispatche_tiles;

        if(full_dispatches < occupancy)
        {
            // fewer full dispatches than occupancy slots: keep the full dispatches as dp tiles
            // and turn only the partial dispatch into sk work, with a single sk slot per cu
            // (an alternative, sk_occupancy = occupancy - full_dispatches, was considered).
            sk_occupancy = 1; // TODO: single occ seems better
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
        {
            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
            //      occupancy = 3, full_dispatches = 5, 8, 11 ...
            //      occupancy = 4, full_dispatches = 7, 11 ...
            sk_occupancy = 1; // left 1 slot for sk occupancy
            dp_tiles     = full_dispatch_tiles;
            sk_tiles     = partial_dispatche_tiles;
        }
        else
        {
            // others, we reduce 1 dispatch from dp, together with partial dispatch,
            // to construct sk dispatch
            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
            dp_tiles     = full_dispatch_tiles - num_cu;
            sk_tiles     = partial_dispatche_tiles + num_cu;
        }

        dp_iters_per_block = k_iters_per_tile;

        sk_total_iters = k_iters_per_tile * sk_tiles;

        {
            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
            uint32_t max_sk_tiles =
                (sk_tiles >= num_cu) ? num_cu * sk_occupancy
                                     : impl::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);

            // if use dp for sk-block, how many iters do we need
            uint32_t dp_for_sk_iters = k_iters_per_tile;

            // We search for the smallest score. Start from the largest uint32_t so any candidate
            // wins and an empty search range falls through to the pure-dp fallback below.
            uint32_t best_sk_score = std::numeric_limits<uint32_t>::max();
            sk_num_blocks          = 0; // stays 0 when the search range is empty
            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
                tentative_sk_blocks++)
            {
                uint32_t tentative_sk_iters_per_block =
                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;

                // TODO: carefully adjust this parameter
                //       the more sk_blocks_per_tile, the worse the overhead
                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
                if(tentative_sk_blocks % sk_tiles != 0)
                {
                    // penalty for uneven divide
                    cross_sk_blocks_overhead +=
                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
                }

                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;

                if(tentative_sk_score < best_sk_score)
                {
                    best_sk_score = tentative_sk_score;
                    sk_num_blocks = tentative_sk_blocks;
                }
            }

            // stream-k is only worth it if it beats running the sk tiles as dp blocks
            if(best_sk_score >= dp_for_sk_iters)
            {
                sk_num_blocks = 0;
            }

            if(sk_num_blocks == 0)
            {
                sk_num_big_blocks     = 0;
                k_iters_per_big_block = 0;

                dp_num_blocks      = num_tiles; // all tile to be dp block
                dp_start_block_idx = 0;
                sk_total_iters     = 0; // clear this tiles
            }
            else
            {
                // distribute sk_total_iters as evenly as possible: the remainder goes one
                // extra iteration each to the first sk_num_big_blocks blocks
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;

                dp_num_blocks      = dp_tiles;
                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
            }
        }
    }
};

// Describes a piece of work on one output tile. All fields default to 0, so a
// default-constructed value is never read uninitialized; it remains an aggregate, so
// brace initialization such as tile_work_t{tile, iter, kb, ke, rem} still works.
struct tile_work_t
{
    uint32_t tile_idx{0};
    uint32_t iter_begin{0};
    uint32_t k_begin{0};
    uint32_t k_end{0};
    uint32_t k_iters_remaining{0};
};

// Host-side simulation of a stream-k launch. It walks every block of the grid reported by
// block_dispatcher_t and splits each block's [iter_start, iter_end) range at tile boundaries,
// walking backwards from iter_end. Each piece is printed, and the program checks that every
// k-iteration of every tile is covered exactly once.
// Returns 0 when the partition is valid, -1 otherwise.
int main(int argc, char** argv)
{
    simple_args_t arg = create_arg(argc, argv);
    block_dispatcher_t block_dispatcher{arg.get_uint32("m_per_block"),
                                        arg.get_uint32("n_per_block"),
                                        arg.get_uint32("k_per_block"),
                                        arg.get_uint32("num_cu"),
                                        arg.get_uint32("occupancy"),
                                        arg.get_uint32("m"),
                                        arg.get_uint32("n"),
                                        arg.get_uint32("k")};
    block_dispatcher.dump();
    // simulate actual kernel launch
    uint32_t dim_x = block_dispatcher.get_grid_dims_x();
    // Sizes are recomputed from the raw arguments rather than read from the dispatcher so the
    // coverage check does not trust the code under test.
    uint32_t total_k_iters =
        impl::integer_divide_ceil(arg.get_uint32("k"), arg.get_uint32("k_per_block"));
    uint32_t num_tiles =
        impl::integer_divide_ceil(arg.get_uint32("m"), arg.get_uint32("m_per_block")) *
        impl::integer_divide_ceil(arg.get_uint32("n"), arg.get_uint32("n_per_block"));
    uint32_t k_iters_per_tile = block_dispatcher.k_iters_per_tile;

    // How many times each (tile, k-iteration) was visited; a valid partition yields all ones.
    // The product is formed in size_t so large problems do not overflow uint32_t.
    std::vector<int> valid_tile_record(static_cast<size_t>(num_tiles) * total_k_iters);

    for(uint32_t bid = 0; bid < dim_x; bid++)
    {
        uint32_t block_idx = block_dispatcher.get_block_idx(bid);
        bool is_sk_block   = block_idx < (block_dispatcher.sk_num_blocks);
        bool is_dp_block   = block_idx >= block_dispatcher.dp_start_block_idx;

        if(!is_sk_block && !is_dp_block)
        {
            // padding between the last sk block and the first dp block: owns no iterations
            printf("[other   ] bid:%3u, block_idx:%3u\n", bid, block_idx);
            continue;
        }

        uint32_t iter_start = 0;
        uint32_t iter_end   = 0;
        block_dispatcher.get_block_itr(block_idx, iter_start, iter_end);

        // Peel the range from the back, one tile-aligned piece at a time. An empty range
        // (e.g. k == 0) is skipped entirely, which also avoids a modulo by zero.
        while(iter_end > iter_start)
        {
            // Distance from iter_end back to the previous tile boundary (a whole tile when
            // iter_end sits on a boundary), capped at what is left of this block. Capping at
            // k_iters_per_tile keeps a piece from spanning several tiles.
            uint32_t iter_length_mod = iter_end % k_iters_per_tile;
            uint32_t current_iter_length =
                impl::min(iter_length_mod == 0 ? k_iters_per_tile : iter_length_mod,
                          iter_end - iter_start);
            uint32_t tile_idx = (iter_end - 1) / k_iters_per_tile;
            uint32_t tile_iter_start =
                ((iter_end - 1) % k_iters_per_tile) - current_iter_length + 1;

            printf("[%s_block] bid:%3u, block_idx:%3u, tile_idx:%3u, iter_start:%u(%u | %u), "
                   "iter_end:%u (len:%u)\n",
                   is_sk_block ? "sk" : "dp",
                   bid,
                   block_idx,
                   tile_idx,
                   iter_end - current_iter_length,
                   tile_iter_start,
                   iter_start,
                   iter_end,
                   current_iter_length);

            // some validation check
            for(uint32_t i = iter_end - current_iter_length; i < iter_end; i++)
            {
                if(i >= valid_tile_record.size())
                {
                    printf("unexpected, current iter:%u larger than max:%zu\n",
                           i,
                           valid_tile_record.size());
                    return -1;
                }
                valid_tile_record[i]++;
            }

            iter_end -= current_iter_length;
        }
    }

    int untouched  = 0;
    int duplicated = 0;
    for(size_t i = 0; i < valid_tile_record.size(); i++)
    {
        if(valid_tile_record[i] == 0)
        {
            printf("untouched at %zu (%zu)\n", i, valid_tile_record.size());
            untouched++;
        }
        else if(valid_tile_record[i] > 1)
        {
            printf("duplicated at %zu (%zu), covered %d times\n",
                   i,
                   valid_tile_record.size(),
                   valid_tile_record[i]);
            duplicated++;
        }
    }
    bool valid = (untouched == 0) && (duplicated == 0);
    printf("untouched %d/%zu, duplicated %d/%zu, %s\n",
           untouched,
           valid_tile_record.size(),
           duplicated,
           valid_tile_record.size(),
           valid ? "valid" : "fail");
    return valid ? 0 : -1;
}