"vscode:/vscode.git/clone" did not exist on "a2aea59b6e191bcf5dc9fb6c862e5496ca8debd4"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

enum struct WorkSchedulingPolicy
{
    StridedTileLoop
};

///
/// @brief      This class describes strided reduction tile loop work scheduling.
///
/// @par Overview
///     This work scheduling policy assumes a linear (strided) mapping of workgroups along
///     the reduced dimension. In a GEMM problem this means that consecutive workgroups are
///     mapped to strided data tiles along the K dimension. Such a mapping can be obtained
///     using e.g. @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along a particular reduced dimension have to reduce their
///     partial results. To do that, global flags and atomics are used to communicate
///     between those workgroups.
///
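/// @par Example usage
///     A minimal, illustrative device-side sketch; `tile_count`, `p_workspace_flags`,
///     `process_tile`, `reduce_partials`, `is_first_k_tile_workgroup`, `k_tiles` and
///     `k_tile_idx` are placeholders, and real kernels may order the synchronization
///     steps differently.
/// @code
/// StridedReductionTileLoop work_scheduler{tile_count, p_workspace_flags};
/// // Accumulate partial results over the contiguous chunk of tiles owned by this workgroup.
/// while(work_scheduler.HasTile())
/// {
///     process_tile(work_scheduler.tile_id_);
///     work_scheduler.GetNextTile();
/// }
/// // Make this workgroup's partial results visible to its peers.
/// work_scheduler.FlagFinished();
/// if(is_first_k_tile_workgroup)
/// {
///     // The workgroup owning the first k-tile gathers the partial results.
///     const index_t neighbours = work_scheduler.WaitForNeighbours(k_tiles, k_tile_idx);
///     reduce_partials(neighbours);
///     work_scheduler.Reset(neighbours);
/// }
/// else
/// {
///     work_scheduler.WaitForReduction();
/// }
/// @endcode
///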
class StridedReductionTileLoop
{
    public:
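    ///
    /// @brief      Construct the work scheduler.
    ///
    ///             Each workgroup is assigned a contiguous chunk of
    ///             ceil(tile_count / grid_size) tiles, starting at block_id * tiles_per_block_.
    ///
    /// @param[in]  tile_count  The total number of data tiles to process.
    /// @param[in]  p_flags     Pointer to the per-workgroup synchronization flags in global
    ///                         memory.
    ///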
    __device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
        : tile_count_{__builtin_amdgcn_readfirstlane(tile_count)},
          tiles_per_block_{__builtin_amdgcn_readfirstlane((tile_count_ + get_grid_size() - 1) /
                                                          get_grid_size())},
          tile_id_{__builtin_amdgcn_readfirstlane(get_block_1d_id() * tiles_per_block_)},
          block_tile_idx_{__builtin_amdgcn_readfirstlane(0)},
          finished_block_flags_{p_flags},
          is_sync_needed_{1}
    {
    }

    __device__ bool HasTile() const
    {
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
        return HasTile();
    }

    /// @brief Advance this workgroup's tile counters and return how many consecutive k-tiles
    ///        to process next.
    /// @param[in]  k_tiles     The number of tiles in the reduced dimension.
    /// @param[in]  k_tile_idx  Current k-tile index.
    /// @return The number of consecutive k-tiles to process next.
    __device__ index_t GetNextKTiles(index_t k_tiles, index_t k_tile_idx)
    {
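        // Take the smaller of the k-tiles remaining in this output tile and the tiles
        // remaining in this workgroup's per-block budget.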
        index_t k_tiles_left     = k_tiles - k_tile_idx;
        index_t block_tiles_left = tiles_per_block_ - block_tile_idx_;
        index_t next_k_tiles = k_tiles_left <= block_tiles_left ? k_tiles_left : block_tiles_left;

        tile_id_ += next_k_tiles;
        block_tile_idx_ += next_k_tiles;
        return next_k_tiles;
    }

    __device__ index_t GetFlagCount() const { return get_grid_size(); }

    ///
    /// @brief      Get this workgroup's flag index.
    ///
    /// @return     The workgroup flag index.
    ///
    __device__ uint32_t GetWorkgroupFlagIdx() const { return static_cast<uint32_t>(blockIdx.x); }

    ///
    /// @brief      Signal that this workgroup has finished its work.
    ///
    __device__ void FlagFinished()
    {
        if(is_sync_needed_)
            finished_block_flags_.inc(GetWorkgroupFlagIdx());
    }

    ///
    /// @brief      Wait until all neighbouring workgroups have finished their work.
    ///
    /// @note       This function assumes it's called by the workgroup which processes the
    ///             first k-tile.
    ///
    /// @param[in]  k_tiles     The number of tiles in the reduced dimension.
    /// @param[in]  k_tile_idx  The currently processed k-tile index.
    ///
    /// @return     The number of neighbours.
    ///
    __device__ index_t WaitForNeighbours(index_t k_tiles, index_t k_tile_idx)
    {
        // We have to wait for all workgroups to finish their partial results.
        // First count how many "neighbour" workgroups we have to check.
        index_t neighbour_count = 0;

        if(!is_sync_needed_)
            return neighbour_count;

        if(tiles_per_block_ < k_tiles)
        {
            // Since the number of neighbours can deviate by +/-1,
            // we calculate how many workgroups are needed to process the k-tiles left.
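            // E.g. (illustrative values): k_tiles = 8, k_tile_idx = 2, tiles_per_block_ = 3
            // gives neighbour_count = (8 - 2 - 1 + 3 - 1) / 3 = 2.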
            neighbour_count = (k_tiles - k_tile_idx - 1 + tiles_per_block_ - 1) / tiles_per_block_;
        }
        // If each workgroup processes at least as many tiles as the reduction dimension size,
        // then the number of neighbours depends on the first K-tile workgroup's block tile idx.
        else
        {
            if(block_tile_idx_ == tiles_per_block_)
            {
                // If this workgroup has just finished its assigned tiles,
                // check at which k-idx we are.
                neighbour_count = (k_tile_idx < k_tiles - 1) ? 1 : 0;
            }
            else
            {
                // If we still have tiles to process, it means that we have already processed
                // the whole K-dim.
                neighbour_count = 0;
            }
        }

        if(neighbour_count > 0)
        {
            // Check if this workgroup's flag is also set (i = 0)
            for(index_t i = 0; i <= neighbour_count; ++i)
            {
                finished_block_flags_.wait_eq(GetWorkgroupFlagIdx() + i, 1);
            }
        }

        return neighbour_count;
    }

    ///
    /// @brief      Wait until the reduction workgroup has finished its work.
    ///
    __device__ void WaitForReduction()
    {
        if(is_sync_needed_)
            // Wait until my counter has been reset.
            finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
    }

    ///
    /// @brief      Reset the flag counters to zero.
    ///
    /// @param[in]  neighbour_count     The number of peer workgroups.
    ///
    __device__ void Reset(index_t neighbour_count)
    {
        if(is_sync_needed_)
        {
            for(index_t i = 0; i <= neighbour_count; ++i)
            {
                finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
            }
        }
    }

    ///
    /// @brief      Get this workgroup's flag value.
    ///
    __device__ uint32_t GetFlagValue() const
    {
        return finished_block_flags_.ld(GetWorkgroupFlagIdx());
    }

    /// @brief      Disable the flag-based synchronization when this workgroup alone covers
    ///             the whole reduced (K) dimension of its output tile.
    __device__ void SetIsSyncNeeded(index_t next_k_tiles, index_t k_tiles)
    {
        is_sync_needed_ = __builtin_amdgcn_readfirstlane(next_k_tiles == k_tiles ? 0 : 1);
    }

    const index_t tile_count_;               ///< Total number of data tiles to process.
    const index_t tiles_per_block_;          ///< Number of tiles assigned to each workgroup.
    index_t tile_id_;                        ///< Global index of this workgroup's current tile.
    index_t block_tile_idx_;                 ///< Index of the current tile within this workgroup's chunk.
    workgroup_barrier finished_block_flags_; ///< Per-workgroup flags used for cross-workgroup sync.
    index_t is_sync_needed_;                 ///< Non-zero if cross-workgroup reduction sync is needed.
};

} // namespace ck