// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

enum struct WorkSchedulingPolicy
{
    StridedTileLoop
};

///
/// @brief      This class implements strided reduction tile-loop work scheduling.
///
/// @par Overview
///     This work scheduling policy assumes a linear (strided) mapping of workgroups along
///     the reduced dimension. In a GEMM problem this means that consecutive workgroups are
///     mapped to strided data tiles along the K dimension. Such a mapping can be obtained
///     using e.g. @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along a particular reduced dimension have to reduce their
///     partial results. To do that, global flags and atomics are used to communicate
///     between those workgroups.
///
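/// @par Example
///     A minimal usage sketch (illustrative only; `p_flags`, `kbatch`, `output_tile_idx`
///     and the accumulation step are placeholders supplied by the calling kernel, not
///     part of this header):
///     @code
///     StridedReductionTileLoop work_scheduler{tile_count, p_flags};
///     do
///     {
///         // ... load the data tile selected by work_scheduler.tile_id_ and
///         //     accumulate a partial result ...
///         //     (assumes the starting tile is in range; a real kernel guards
///         //      against out-of-range tiles)
///     } while(work_scheduler.GetNextTile());
///
///     // Signal completion and wait for the other workgroups reducing the same
///     // output (MN) tile before producing the final result.
///     work_scheduler.FlagFinished(kbatch, output_tile_idx);
///     work_scheduler.WaitForNeighbours(kbatch, output_tile_idx);
///     @endcode
///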
class StridedReductionTileLoop
{
    public:
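    ///
    /// @brief      Construct the scheduler and assign this workgroup its strided range of tiles.
    ///
    /// @param[in]  tile_count    The total number of data tiles to process.
    /// @param[in]  p_flag_count  Pointer to the global flag counters used for
    ///                           inter-workgroup synchronization.
    ///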
    __device__ StridedReductionTileLoop(index_t tile_count,
                                        uint32_t* const __restrict__ p_flag_count)
        : tile_count_{tile_count},
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
          finished_block_flags_{p_flag_count}
    {
    }

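    ///
    /// @brief      Advance to this workgroup's next data tile.
    ///
    /// @return     True if the new tile is still within this workgroup's assigned range and
    ///             within the total tile count, false otherwise.
    ///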
    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

    ///
    /// @brief      Calculate this workgroup's flag index.
    ///
    /// @note       This scheduler intentionally does not store the flag index as a member,
    ///             since the number of `dim_tiles` may change while iterating (e.g. in
    ///             grouped GEMM, different groups may have different `dim_tiles` in the K
    ///             dimension).
    ///
    /// @param[in]  dim_tiles        The number of data tiles in the reduced dimension.
    /// @param[in]  output_tile_idx  The output (MN) tile index.
    ///
    /// @return     The workgroup flag index.
    ///
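    /// @par Example
    ///     A hypothetical configuration (numbers chosen for illustration only): with a grid
    ///     of 8 workgroups, `tiles_per_block_ = 4` and `dim_tiles = 4` (k_batch), the flag
    ///     count is `(8 * 4 + 4 - 1) / 4 = 8`, i.e. one flag per covered output (MN) tile,
    ///     and `GetWorkgroupFlagIdx(4, 5)` returns `5 % 8 = 5`.
    ///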
    __device__ index_t GetWorkgroupFlagIdx(index_t dim_tiles, index_t output_tile_idx) const
    {
        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch dim_tiles (k_batch) / tiles_per_block workgroups for each output tile.
        const index_t flag_count = (get_grid_size() * tiles_per_block_ + dim_tiles - 1) / dim_tiles;
        return output_tile_idx % flag_count;
    }

    ///
    /// @brief      Signal that this workgroup has finished its work on the given output tile.
    ///
    /// @param[in]  dim_tiles        The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx  The output (MN) tile index.
    ///
    __device__ void FlagFinished(index_t dim_tiles, index_t output_tile_idx)
    {
        finished_block_flags_.inc(GetWorkgroupFlagIdx(dim_tiles, output_tile_idx));
    }

    ///
    /// @brief      Wait until all workgroups contributing to the given output tile have finished.
    ///
    /// @param[in]  dim_tiles        The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx  The output (MN) tile index.
    ///
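    /// @note       As a hypothetical example, with `dim_tiles = 8` and `tiles_per_block_ = 2`,
    ///             four workgroups contribute to each output tile, so this call spins until
    ///             the shared counter reaches 4 and then resets it to 0 for reuse.
    ///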
    __device__ void WaitForNeighbours(index_t dim_tiles, index_t output_tile_idx)
    {
        // Wait until all workgroups contributing to this output tile finish, then reset the counter.
        const index_t workgroups_per_dim = (dim_tiles + tiles_per_block_ - 1) / tiles_per_block_;
        finished_block_flags_.wait_set(
            GetWorkgroupFlagIdx(dim_tiles, output_tile_idx), workgroups_per_dim, 0);
    }

    const index_t tile_count_;      ///< Total number of data tiles to process.
    const index_t tiles_per_block_; ///< Number of tiles assigned to each workgroup (rounded up).
    index_t tile_id_;               ///< Global index of the tile currently assigned to this workgroup.
    index_t block_tile_idx_;        ///< Index of the current tile within this workgroup's assigned range.
    workgroup_barrier finished_block_flags_; ///< Global flag counters used for inter-workgroup synchronization.
};

} // namespace ck