// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

/// Work-scheduling policies that can be selected for a kernel.
enum class WorkSchedulingPolicy : int
{
    StridedTileLoop = 0, ///< Strided tile-loop scheduling.
};

///
/// @brief      This class describes a strided reduction tile loop work scheduling.
///
///
/// @par Overview
///     This work scheduling policy assume linear mapping (with stride) of workgroups along
///     the reduced dimension. In GEMM problem this mean that consecutive workgroups are mapped
///     to strided data tiles along K dimension. This can be obtained using i.e.
///     @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along particular reduced dimension have to reduce their partial
///     results. In order to do that there's a need to use global flags and atomics to communicate
///     between those workgroups.
///
class StridedReductionTileLoop
{
    public:
35
    __device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
36
37
38
39
        : tile_count_{tile_count},
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
40
          finished_block_flags_{p_flags}
41
42
43
    {
    }

44
45
46
47
48
    __device__ bool HasTile() const
    {
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

49
50
51
52
    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
Adam Osewski's avatar
Adam Osewski committed
53
        return HasTile();
54
55
    }

56
57
58
59
60
61
    /// @brief Returns the number of next k-tiles to process.
    /// @param[in]  k_tiles     The number of tiles in the reduced dimension.
    /// @param[in]  k_tile_idx  Current k-tile index.
    /// @return The number of next k-tiles to process.
    __device__ index_t GetNextKTiles(index_t k_tiles, index_t k_tile_idx)
    {
Adam Osewski's avatar
Adam Osewski committed
62
63
64
65
        index_t k_tiles_left     = k_tiles - k_tile_idx;
        index_t block_tiles_left = tiles_per_block_ - block_tile_idx_;
        index_t next_k_tiles = k_tiles_left <= block_tiles_left ? k_tiles_left : block_tiles_left;

66
67
68
69
70
        tile_id_ += next_k_tiles;
        block_tile_idx_ += next_k_tiles;
        return next_k_tiles;
    }

71
    __device__ index_t GetFlagCount() const { return get_grid_size(); }
72

73
    ///
74
    /// @brief      Get this workgroup flag index.
75
76
77
    ///
    /// @return     The workgroup flag index.
    ///
78
    __device__ uint32_t GetWorkgroupFlagIdx() const { return static_cast<uint32_t>(blockIdx.x); }
79
80
81
82

    ///
    /// @brief      Flag each workgroup that has finished its work.
    ///
83
    __device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }
84
85
86
87

    ///
    /// @brief      Wait until each workgroup has finished its work.
    ///
88
89
    /// @param[in]  k_tiles     The number of tiles in the reduced dimension.
    /// @param[in]  k_tile_idx  The currently processed tile k index.
90
    ///
91
92
    /// @return     The number of neighbours.
    ///
93
    __device__ index_t WaitForNeighbours(index_t k_tiles, index_t k_tile_idx)
94
    {
95
96
97
98
99
        // We have to wait for all workgroups to finish their partial results.
        // First count how many "neighbour" workgroups we have to check.
        index_t neighbour_count = 0;
        if(tiles_per_block_ < k_tiles)
        {
Adam Osewski's avatar
Adam Osewski committed
100
            // Since we can have deviation (+/-1) in neighbours number
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
            // we calculate how many workgroups are needed to process the k-tiles left.
            neighbour_count = (k_tiles - k_tile_idx - 1 + tiles_per_block_ - 1) / tiles_per_block_;
        }
        // If we have more tiles to process than the reduction dimension size,
        // then the number of neighbours depends on first K-tile workgroup block tile idx.
        else
        {
            if(block_tile_idx_ == tiles_per_block_)
            {
                // If we just finished work per workgroup then check at which k-idx we are.
                neighbour_count = (k_tile_idx < k_tiles - 1) ? 1 : 0;
            }
            else
            {
                // If we have still tiles to process then it means that we already processed
                // whole K-dim.
                neighbour_count = 0;
            }
        }

        if(neighbour_count > 0)
        {
123
124
            // Check if this workgroup's flag is also set (i = 0)
            for(index_t i = 0; i <= neighbour_count; ++i)
125
            {
126
127
                finished_block_flags_.wait_eq(GetWorkgroupFlagIdx() + i, 1);
            }
128
129
130
        }

        return neighbour_count;
131
132
133
    }

    ///
134
    /// @brief      Wait until reduction workgroup has finished its work.
135
    ///
136
    __device__ void WaitForReduction()
137
    {
138
139
        // Wait untill my counter has been reset.
        finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
140
141
142
143
144
    }

    ///
    /// @brief      Reset flag counter to zero.
    ///
145
    /// @param[in]  neighbour_count     The number of peer workgroups.
146
    ///
147
    __device__ void Reset(index_t neighbour_count)
148
    {
149
150
151
152
        for(index_t i = 0; i <= neighbour_count; ++i)
        {
            finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
        }
153
154
155
156
157
    }

    ///
    /// @brief      Gets the flag value.
    ///
158
    __device__ uint32_t GetFlagValue() const
159
    {
160
        return finished_block_flags_.ld(GetWorkgroupFlagIdx());
161
162
163
164
165
166
167
168
169
170
    }

    const index_t tile_count_;
    const index_t tiles_per_block_;
    index_t tile_id_;
    index_t block_tile_idx_;
    workgroup_barrier finished_block_flags_;
};

} // namespace ck