// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

///
/// @brief      Enumerates the available work scheduling policies.
///
enum class WorkSchedulingPolicy
{
    /// Strided tile loop scheduling policy.
    StridedTileLoop
};

///
/// @brief      This class describes a strided reduction tile loop work scheduling.
///
///
/// @par Overview
///     This work scheduling policy assume linear mapping (with stride) of workgroups along
///     the reduced dimension. In GEMM problem this mean that consecutive workgroups are mapped
///     to strided data tiles along K dimension. This can be obtained using i.e.
///     @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along particular reduced dimension have to reduce their partial
///     results. In order to do that there's a need to use global flags and atomics to communicate
///     between those workgroups.
///
class StridedReductionTileLoop
{
    public:
35
    __device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
36
37
38
39
        : tile_count_{tile_count},
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
40
          finished_block_flags_{p_flags}
41
42
43
    {
    }

44
45
46
47
48
    __device__ bool HasTile() const
    {
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

49
50
51
52
    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
Adam Osewski's avatar
Adam Osewski committed
53
        return HasTile();
54
55
    }

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    /// @brief Returns the number of next k-tiles to process.
    /// @param[in]  k_tiles     The number of tiles in the reduced dimension.
    /// @param[in]  k_tile_idx  Current k-tile index.
    /// @return The number of next k-tiles to process.
    __device__ index_t GetNextKTiles(index_t k_tiles, index_t k_tile_idx)
    {
        index_t k_tiles_left = k_tiles - k_tile_idx;
        index_t next_k_tiles =
            k_tiles_left <= tiles_per_block_ ? k_tiles_left : tiles_per_block_ - block_tile_idx_;
        tile_id_ += next_k_tiles;
        block_tile_idx_ += next_k_tiles;

        if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
        {
            printf("[GetNextKTiles] bid: %d, k_tiles: %d, k_idx:%d, next_k_tiles: %d, "
                   "k_tiles_left: %d,"
                   " tile_id: %d, block_tile_idx: %d\n",
                   static_cast<index_t>(blockIdx.x),
                   k_tiles,
                   k_tile_idx,
                   next_k_tiles,
                   k_tiles_left,
                   tile_id_,
                   block_tile_idx_);
        }

        return next_k_tiles;
    }

85
    __device__ index_t GetFlagCount() const { return get_grid_size(); }
86

87
    ///
88
    /// @brief      Get this workgroup flag index.
89
90
91
    ///
    /// @return     The workgroup flag index.
    ///
92
    __device__ uint32_t GetWorkgroupFlagIdx() const { return static_cast<uint32_t>(blockIdx.x); }
93
94
95
96

    ///
    /// @brief      Flag each workgroup that has finished its work.
    ///
97
    __device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }
98
99
100
101

    ///
    /// @brief      Wait until each workgroup has finished its work.
    ///
102
103
    /// @param[in]  k_tiles     The number of tiles in the reduced dimension.
    /// @param[in]  k_tile_idx  The currently processed tile k index.
104
    ///
105
106
    /// @return     The number of neighbours.
    ///
107
    __device__ index_t WaitForNeighbours(index_t k_tiles, index_t k_tile_idx)
108
    {
109
110
111
112
113
        // We have to wait for all workgroups to finish their partial results.
        // First count how many "neighbour" workgroups we have to check.
        index_t neighbour_count = 0;
        if(tiles_per_block_ < k_tiles)
        {
Adam Osewski's avatar
Adam Osewski committed
114
            // Since we can have deviation (+/-1) in neighbours number
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
            // we calculate how many workgroups are needed to process the k-tiles left.
            neighbour_count = (k_tiles - k_tile_idx - 1 + tiles_per_block_ - 1) / tiles_per_block_;
        }
        // If we have more tiles to process than the reduction dimension size,
        // then the number of neighbours depends on first K-tile workgroup block tile idx.
        else
        {
            if(block_tile_idx_ == tiles_per_block_)
            {
                // If we just finished work per workgroup then check at which k-idx we are.
                neighbour_count = (k_tile_idx < k_tiles - 1) ? 1 : 0;
            }
            else
            {
                // If we have still tiles to process then it means that we already processed
                // whole K-dim.
                neighbour_count = 0;
            }
        }

        if(neighbour_count > 0)
        {
137
138
            // Check if this workgroup's flag is also set (i = 0)
            for(index_t i = 0; i <= neighbour_count; ++i)
139
            {
140
141
                finished_block_flags_.wait_eq(GetWorkgroupFlagIdx() + i, 1);
            }
142
143
144
        }

        return neighbour_count;
145
146
147
    }

    ///
148
    /// @brief      Wait until reduction workgroup has finished its work.
149
    ///
150
    __device__ void WaitForReduction()
151
    {
152
153
        // Wait untill my counter has been reset.
        finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
154
155
156
157
158
    }

    ///
    /// @brief      Reset flag counter to zero.
    ///
159
    /// @param[in]  neighbour_count     The number of peer workgroups.
160
    ///
161
    __device__ void Reset(index_t neighbour_count)
162
    {
163
164
165
166
        for(index_t i = 0; i <= neighbour_count; ++i)
        {
            finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
        }
167
168
169
170
171
    }

    ///
    /// @brief      Gets the flag value.
    ///
172
    __device__ uint32_t GetFlagValue() const
173
    {
174
        return finished_block_flags_.ld(GetWorkgroupFlagIdx());
175
176
177
178
179
180
181
182
183
184
    }

    const index_t tile_count_;
    const index_t tiles_per_block_;
    index_t tile_id_;
    index_t block_tile_idx_;
    workgroup_barrier finished_block_flags_;
};

} // namespace ck