// work_scheduling.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

///
/// @brief      Enumerates the work scheduling policies known to this header.
///
enum struct WorkSchedulingPolicy
{
    // NOTE(review): presumably selects the scheduling implemented by
    // StridedReductionTileLoop below; confirm at the use sites.
    StridedTileLoop
};

///
/// @brief      This class describes a strided reduction tile loop work scheduling.
///
/// @par Overview
///     This work scheduling policy assumes a linear mapping (with stride) of workgroups along
///     the reduced dimension. In a GEMM problem this means that consecutive workgroups are
///     mapped to strided data tiles along the K dimension. This can be obtained using e.g.
///     @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along a particular reduced dimension have to reduce their partial
///     results. In order to do that there's a need to use global flags and atomics to
///     communicate between those workgroups.
///
class StridedReductionTileLoop
{
    public:
    ///
    /// @brief      Construct the scheduler state for the calling workgroup.
    ///
    /// @param[in]  tile_count    The total number of tiles to distribute over the grid.
    /// @param[in]  p_flag_count  Pointer to the flag counters; it is handed over to the
    ///                           `workgroup_barrier` member used for synchronization.
    ///
    __device__ StridedReductionTileLoop(index_t tile_count,
                                        uint32_t* const __restrict__ p_flag_count)
        : tile_count_{tile_count},
          // Ceiling division, so that grid_size * tiles_per_block_ >= tile_count_.
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          // Every workgroup starts at the first tile of its own run of tiles.
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
          finished_block_flags_{p_flag_count}
    {
    }

    ///
    /// @brief      Check whether the current tile is valid for this workgroup.
    ///
    /// @return     True if the current tile index is within the total tile count and this
    ///             workgroup has not yet consumed all of its `tiles_per_block_` tiles.
    ///
    __device__ bool HasTile() const
    {
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

    ///
    /// @brief      Advance to the next tile assigned to this workgroup.
    ///
    /// @return     True if, after advancing, there is still a valid tile to process
    ///             (same condition as HasTile()).
    ///
    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
        return HasTile();
    }

    ///
    /// @brief      Calculate this workgroup flag index.
    ///
    /// @note       Note this scheduler intentionally does not have flag index as its member,
    ///             since the number of `k_tiles` may change when iterating (i.e. in grouped
    ///             gemm, different groups may have different `k_tiles` in K dimension).
    ///
    /// @note       If `tiles_per_block_` is zero (which happens when `tile_count` is zero) the
    ///             computed flag count is zero and the modulo below is undefined. Callers are
    ///             presumably expected to check HasTile() first - TODO confirm.
    ///
    /// @param[in]  k_tiles                 The number of data tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index (of current GEMM).
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    /// @return     The workgroup flag index.
    ///
    __device__ index_t GetWorkgroupFlagIdx(index_t k_tiles,
                                           index_t output_tile_idx,
                                           index_t output_tile_idx_offset) const
    {
        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
        const index_t flag_count = (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
        return (output_tile_idx + output_tile_idx_offset) % flag_count;
    }

    ///
    /// @brief      Flag each workgroup that has finished its work.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        finished_block_flags_.inc(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
    }

    ///
    /// @brief      Wait until each workgroup has finished its work.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    WaitForNeighbours(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        // Wait until all workgroups finish.
        // Number of workgroups that share one reduced dimension (ceiling division).
        const index_t workgroups_per_dim = (k_tiles + tiles_per_block_ - 1) / tiles_per_block_;
        // We use < because for some cases we may have +1 more workgroups per dim.
        // I.e. when k_tiles = 5, tiles_per_block = 3.
        // NOTE(review): presumably wait_lt blocks while the flag is below the given value;
        // confirm against workgroup_barrier.
        finished_block_flags_.wait_lt(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
            workgroups_per_dim);
    }

    ///
    /// @brief      Wait until the reduction is done, i.e. the flag counter has been reset.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void
    WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        // Wait until the counter has been reset (compared against 0).
        finished_block_flags_.wait_eq(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
    }

    ///
    /// @brief      Reset flag counter to zero.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    __device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
    {
        finished_block_flags_.reset(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
    }

    ///
    /// @brief      Gets the flag value.
    ///
    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
    /// @param[in]  output_tile_idx         The output (MN) tile index.
    /// @param[in]  output_tile_idx_offset  The output tile index offset.
    ///
    /// @return     The current value of this workgroup's flag, converted to `index_t`.
    ///
    __device__ index_t GetFlagValue(index_t k_tiles,
                                    index_t output_tile_idx,
                                    index_t output_tile_idx_offset) const
    {
        return static_cast<index_t>(finished_block_flags_.ld(
            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset)));
    }

    /// Total number of tiles passed to the constructor.
    const index_t tile_count_;
    /// Number of tiles assigned to each workgroup: ceil(tile_count_ / grid size).
    const index_t tiles_per_block_;
    /// Index of the tile this workgroup is currently positioned at.
    index_t tile_id_;
    /// Number of tiles this workgroup has already advanced past.
    index_t block_tile_idx_;
    /// Barrier object wrapping the flag counters used to synchronize workgroups.
    workgroup_barrier finished_block_flags_;
};

} // namespace ck