work_scheduling.hpp 6.34 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

///
/// @brief      Enumerates the available work scheduling policies.
///
enum struct WorkSchedulingPolicy
{
    /// Tile-loop scheduling; see `StridedReductionTileLoop` in this header.
    StridedTileLoop
};

///
/// @brief      This class describes a strided reduction tile loop work scheduling.
///
///
/// @par Overview
///     This work scheduling policy assumes a linear mapping (with stride) of workgroups along
///     the reduced dimension. In a GEMM problem this means that consecutive workgroups are
///     mapped to strided data tiles along the K dimension. This can be obtained using e.g.
///     @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along particular reduced dimension have to reduce their partial
///     results. In order to do that there's a need to use global flags and atomics to communicate
///     between those workgroups.
///
class StridedReductionTileLoop
{
    public:
    ///
    /// @brief      Construct the scheduler state for the calling workgroup.
    ///
    /// @par Overview
    ///     Each workgroup is assigned `ceil(tile_count / get_grid_size())` consecutive tile ids,
    ///     starting at `get_block_1d_id() * tiles_per_block_`.
    ///
    /// @param[in]  tile_count    The total number of tiles; valid tile ids are
    ///                           `[0, tile_count)`.
    /// @param[in]  p_flag_count  The pointer to the flag counters used to construct the
    ///                           workgroup barrier.
    ///
    __device__ StridedReductionTileLoop(index_t tile_count,
                                        uint32_t* const __restrict__ p_flag_count)
        : tile_count_{tile_count},
          // Ceil-division. Reading `tile_count_` here is safe, since it is declared (and thus
          // initialized) before `tiles_per_block_`.
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
          finished_block_flags_{p_flag_count}
    {
    }

    ///
    /// @brief      Check whether this workgroup still has a tile to process.
    ///
    /// @return     True if the current tile id is lower than the tile count and this workgroup
    ///             has not yet consumed all of its `tiles_per_block_` tiles.
    ///
    __device__ bool HasTile() const
    {
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

    ///
    /// @brief      Advance to the next tile assigned to this workgroup.
    ///
    /// @return     True if the tile reached after advancing is valid (@see HasTile).
    ///
    __device__ bool GetNextTile()
    {
        ++tile_id_;
        ++block_tile_idx_;
        return HasTile();
    }

    ///
    /// @brief      Calculate the number of flags in use.
    ///
    /// @param[in]  k_tiles  The number of data tiles in the reduced dimension. It is used as a
    ///                      divisor, thus it must not be zero.
    ///
    /// @return     `ceil(get_grid_size() * tiles_per_block_ / k_tiles)`.
    ///
    __device__ index_t GetFlagCount(index_t k_tiles) const
    {
        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
        return (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
    }

    ///
    /// @brief      Calculate this workgroup flag index.
    ///
    /// @note       Note this scheduler intentionally does not have flag index as its member,
    ///             since the number of `k_tiles` may change when iterating (ie. in grouped gemm,
    ///             different groups may have different `k_tiles` in K dimension).
    ///
    /// @param[in]  k_tiles                 The number of data tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index (of current GEMM).
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    /// @return     The workgroup flag index.
    ///
    __device__ index_t GetWorkgroupFlagIdx(index_t k_tiles,
                                           index_t output_tile_idx,
                                           index_t output_tile_idx_offset) const
    {
        // Wrap around the number of flags in use.
        return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
    }

    ///
    /// @brief      Flag each workgroup that has finished its work.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        finished_block_flags_.inc(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
    }

    ///
    /// @brief      Wait until each workgroup has finished its work.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    WaitForNeighbours(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        // Wait until all workgroups finish.
        // Number of workgroups per output tile: ceil(k_tiles / tiles_per_block_).
        const index_t workgroups_per_dim = (k_tiles + tiles_per_block_ - 1) / tiles_per_block_;
        // We use < because for some cases we may have +1 more workgroups per dim.
        // Ie when k_tiles = 5, tiles_per_block = 3.
        // NOTE(review): the exact blocking semantics of `wait_lt` are defined in
        // workgroup_barrier.hpp - confirm there.
        finished_block_flags_.wait_lt(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
            workgroups_per_dim);
    }

    ///
    /// @brief      Wait until the flag counter has been reset to zero.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        // Wait until the counter has been reset.
        finished_block_flags_.wait_eq(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
    }

    ///
    /// @brief      Reset flag counter to zero.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        finished_block_flags_.reset(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
    }

    ///
    /// @brief      Gets the flag value.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    /// @return     The value loaded from this workgroup's flag, converted to `index_t`.
    ///
    __device__ index_t GetFlagValue(index_t k_tiles,
                                    index_t output_tile_idx,
                                    index_t output_tile_idx_offset) const
    {
        return static_cast<index_t>(finished_block_flags_.ld(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset)));
    }

    // The total number of tiles (upper bound for `tile_id_`).
    const index_t tile_count_;
    // The number of tiles assigned to each workgroup: ceil(tile_count_ / get_grid_size()).
    const index_t tiles_per_block_;
    // The id of the tile currently processed by this workgroup.
    index_t tile_id_;
    // The number of tiles this workgroup has already advanced past.
    index_t block_tile_idx_;
    // The barrier object wrapping the flag counters passed to the constructor.
    workgroup_barrier finished_block_flags_;
};

} // namespace ck