// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

///
/// @brief      Enumerates the available work scheduling policies.
///
/// @note       StridedTileLoop is currently the only enumerator; presumably it selects the
///             StridedReductionTileLoop scheduler declared in this file (to be confirmed
///             against the call sites).
///
enum class WorkSchedulingPolicy
{
    StridedTileLoop
};

///
/// @brief      This class describes a strided reduction tile loop work scheduling.
///
///
/// @par Overview
///     This work scheduling policy assumes a linear mapping (with stride) of workgroups along
///     the reduced dimension. In a GEMM problem this means that consecutive workgroups are mapped
///     to strided data tiles along the K dimension. This can be obtained using e.g.
///     @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along particular reduced dimension have to reduce their partial
///     results. In order to do that there's a need to use global flags and atomics to communicate
///     between those workgroups.
///
class StridedReductionTileLoop
{
    public:
    __device__ StridedReductionTileLoop(index_t tile_count,
36
                                        volatile uint32_t* const __restrict__ p_flags)
37
38
39
40
        : tile_count_{tile_count},
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
41
          finished_block_flags_{p_flags}
42
43
44
    {
    }

45
46
47
48
49
    __device__ bool HasTile() const
    {
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

50
51
52
53
    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
Adam Osewski's avatar
Adam Osewski committed
54
        return HasTile();
55
56
    }

57
58
59
60
61
62
63
    __device__ index_t GetFlagCount(index_t k_tiles) const
    {
        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
        return (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
    }

64
65
66
67
    ///
    /// @brief      Calculate this workgroup flag index.
    ///
    /// @note       Note this scheduler intentionaly does not have flag index as its member, since
68
69
    ///             current workgroup may process tiles across different MN-output tiles or
    ///             acorss different GEMMs (grouped gemm).
70
    ///
71
72
73
74
    /// @param[in]  k_tiles                 The number of data tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) linear tile index (of current GEMM).
    /// @param[in]  output_tile_idx_offset  The accumulated offset of output tiles from previous
    ///                                     GEMMs.
75
76
77
    ///
    /// @return     The workgroup flag index.
    ///
Adam Osewski's avatar
Adam Osewski committed
78
    __device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
79
80
                                            index_t output_tile_idx,
                                            index_t output_tile_idx_offset) const
81
    {
82
        return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
83
84
85
86
87
    }

    ///
    /// @brief      Flag each workgroup that has finished its work.
    ///
88
89
90
91
92
93
94
    /// @param[in]  k_tiles               The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index
    /// @param[in]  output_tile_idx_offset  The output tile index offset
    ///
    __device__ void
    FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
95
        const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
96
        finished_block_flags_.inc(fidx);
97
98
99
100
101
102
    }

    ///
    /// @brief      Wait until each workgroup has finished its work.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
103
    /// @param[in]  k_tile_idx              The currently processed tile k index.
104
105
    /// @param[in]  output_tile_idx         The output (MN) tile index
    /// @param[in]  output_tile_idx_offset  The output tile index offset
106
    ///
107
108
109
110
111
112
    /// @return     The number of neighbours.
    ///
    __device__ index_t WaitForNeighbours(index_t k_tiles,
                                         index_t k_tile_idx,
                                         index_t output_tile_idx,
                                         index_t output_tile_idx_offset)
113
    {
114
115
116
117
118
        // We have to wait for all workgroups to finish their partial results.
        // First count how many "neighbour" workgroups we have to check.
        index_t neighbour_count = 0;
        if(tiles_per_block_ < k_tiles)
        {
Adam Osewski's avatar
Adam Osewski committed
119
            // Since we can have deviation (+/-1) in neighbours number
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
            // we calculate how many workgroups are needed to process the k-tiles left.
            neighbour_count = (k_tiles - k_tile_idx - 1 + tiles_per_block_ - 1) / tiles_per_block_;
        }
        // If we have more tiles to process than the reduction dimension size,
        // then the number of neighbours depends on first K-tile workgroup block tile idx.
        else
        {
            if(block_tile_idx_ == tiles_per_block_)
            {
                // If we just finished work per workgroup then check at which k-idx we are.
                neighbour_count = (k_tile_idx < k_tiles - 1) ? 1 : 0;
            }
            else
            {
                // If we have still tiles to process then it means that we already processed
                // whole K-dim.
                neighbour_count = 0;
            }
        }

        if(neighbour_count > 0)
        {
Adam Osewski's avatar
Adam Osewski committed
142
143
144
            // Also count this workgroup
            neighbour_count++;
            finished_block_flags_.wait_eq(
145
146
147
148
149
                GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
                neighbour_count);
        }

        return neighbour_count;
150
151
152
153
154
    }

    ///
    /// @brief      Wait until each workgroup has finished its work.
    ///
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index
    /// @param[in]  output_tile_idx_offset  The output tile index offset
    ///
    __device__ void
    WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        // Wait untill the counter has been reset.
        finished_block_flags_.wait_eq(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
    }

    ///
    /// @brief      Reset flag counter to zero.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        finished_block_flags_.reset(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
    }

    ///
    /// @brief      Gets the flag value.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
186
    ///
187
188
189
    __device__ uint32_t GetFlagValue(index_t k_tiles,
                                     index_t output_tile_idx,
                                     index_t output_tile_idx_offset) const
190
    {
191
192
        return finished_block_flags_.ld(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
193
194
195
196
197
198
199
200
201
202
    }

    const index_t tile_count_;
    const index_t tiles_per_block_;
    index_t tile_id_;
    index_t block_tile_idx_;
    workgroup_barrier finished_block_flags_;
};

} // namespace ck