// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

enum struct WorkSchedulingPolicy
{
    StridedTileLoop
};

///
/// @brief      This class implements strided reduction tile loop work scheduling.
///
/// @par Overview
///     This work scheduling policy assumes a linear (strided) mapping of workgroups along
///     the reduced dimension. In a GEMM problem this means that consecutive workgroups are
///     mapped to strided data tiles along the K dimension. Such a mapping can be obtained
///     using e.g. @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along a particular reduced dimension have to reduce their partial
///     results. To do that, global flags and atomics are used to communicate between those
///     workgroups.
///
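/// @par Usage sketch
///     A rough, illustrative outline of how a kernel could drive this scheduler. The way the
///     partial results are accumulated and reduced, and the way output_tile_idx /
///     output_tile_idx_offset are obtained (e.g. from a block-to-C-tile map), are hypothetical
///     placeholders rather than an API defined in this header.
///     @code
///     StridedReductionTileLoop work_scheduler{tile_count, p_flags};
///
///     // Accumulate partial results over this workgroup's strided K-tiles.
///     while(work_scheduler.HasTile())
///     {
///         // ... load the data tile indexed by work_scheduler.tile_id_ and accumulate ...
///         work_scheduler.GetNextTile();
///     }
///
///     // Signal completion for the output tile this workgroup contributed to.
///     work_scheduler.FlagFinished(k_tiles, output_tile_idx, output_tile_idx_offset);
///
///     // The workgroup responsible for the final reduction waits for its neighbours,
///     // reduces their partial results and resets the flag for reuse.
///     work_scheduler.WaitForNeighbours(k_tiles, output_tile_idx, output_tile_idx_offset);
///     // ... reduce partial results ...
///     work_scheduler.Reset(k_tiles, output_tile_idx, output_tile_idx_offset);
///     @endcode
///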
class StridedReductionTileLoop
{
    public:
    __device__ StridedReductionTileLoop(index_t tile_count,
                                        uint32_t* const __restrict__ p_flag_count)
        : tile_count_{tile_count},
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
          finished_block_flags_{p_flag_count}
    {
    }

    __device__ bool HasTile() const
    {
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
        return HasTile();
    }

    __device__ index_t GetFlagCount(index_t k_tiles) const
    {
        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
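        // E.g. (illustrative numbers, not taken from any kernel config): with grid_size = 128,
        // tiles_per_block_ = 4 and k_tiles = 8, this gives (128 * 4 + 8 - 1) / 8 = 64 flags.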
        return (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
    }

    ///
    /// @brief      Calculate this workgroup's flag index.
    ///
    /// @note       This scheduler intentionally does not store the flag index as a member, since
    ///             the current workgroup may process tiles across different MN-output tiles or
    ///             across different GEMMs (grouped GEMM).
    ///
    /// @param[in]  k_tiles                 The number of data tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) linear tile index (of the current GEMM).
    /// @param[in]  output_tile_idx_offset  The accumulated offset of output tiles from previous
    ///                                     GEMMs.
    ///
    /// @return     The workgroup flag index.
    ///
    __device__ uint32_t GetWorkgroupFlagIdx([[maybe_unused]] index_t k_tiles,
                                            index_t output_tile_idx,
                                            index_t output_tile_idx_offset) const
    {
        // return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
        return output_tile_idx + output_tile_idx_offset;
    }

    ///
    /// @brief      Flag that this workgroup has finished its work.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
        finished_block_flags_.inc(fidx);
    }

    ///
    /// @brief      Wait until all workgroups contributing to the current output tile have flagged
    ///             that they finished their work.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    WaitForNeighbours(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        // Wait until all workgroups finish.
        const index_t workgroups_per_dim = (k_tiles + tiles_per_block_ - 1) / tiles_per_block_;
        // We use wait_lt (<) rather than an exact comparison because in some cases one more
        // workgroup may contribute to a reduced dimension, e.g. when k_tiles = 5 and
        // tiles_per_block = 3 a workgroup that straddles the boundary between two output tiles
        // increments both of their flags.
        finished_block_flags_.wait_lt(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
            workgroups_per_dim);
    }

    ///
    /// @brief      Wait until the reduction of the current output tile has completed, i.e. until
    ///             its flag counter has been reset.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        // Wait until the counter has been reset.
        finished_block_flags_.wait_eq(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
    }

    ///
    /// @brief      Reset flag counter to zero.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        finished_block_flags_.reset(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
    }

    ///
    /// @brief      Gets the flag value.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ uint32_t GetFlagValue(index_t k_tiles,
                                     index_t output_tile_idx,
                                     index_t output_tile_idx_offset) const
    {
        return finished_block_flags_.ld(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
    }

    // Total number of data tiles to process across the whole grid.
    const index_t tile_count_;
    // Number of consecutive tiles assigned to each workgroup.
    const index_t tiles_per_block_;
    // Global index of the tile currently processed by this workgroup.
    index_t tile_id_;
    // Local index of the current tile within this workgroup's assigned range.
    index_t block_tile_idx_;
    // Global flag counters used to synchronize workgroups sharing an output tile.
    workgroup_barrier finished_block_flags_;
};

} // namespace ck