// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/workgroup_barrier.hpp"

namespace ck {

// Available strategies for distributing reduction work across workgroups.
enum struct WorkSchedulingPolicy
{
    StridedTileLoop // Workgroups loop over tiles mapped with a stride along the reduced dim.
};

///
/// @brief      This class describes a strided reduction tile loop work scheduling.
///
///
/// @par Overview
///     This work scheduling policy assumes a linear mapping (with stride) of workgroups along
///     the reduced dimension. In a GEMM problem this means that consecutive workgroups are mapped
///     to strided data tiles along the K dimension. This can be obtained using e.g.
///     @see BlockToCTileMap_ReduceKSplit.
///
/// @par Synchronization
///     All workgroups aligned along a particular reduced dimension have to reduce their partial
///     results. In order to do that there's a need to use global flags and atomics to communicate
///     between those workgroups.
///
class StridedReductionTileLoop
{
    public:
    ///
    /// @brief      Construct the scheduler state for this workgroup.
    ///
    /// @param[in]  tile_count  Total number of data tiles to process across the whole grid.
    /// @param[in]  p_flags     Global-memory synchronization flags; one flag per workgroup
    ///                         (see GetFlagCount()).
    ///
    __device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
        : tile_count_{tile_count},
          // Ceil-divide so every tile is covered even when tile_count_ is not a
          // multiple of the grid size.
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          // This workgroup's first tile: workgroups own consecutive chunks of tile ids.
          tile_id_{get_block_1d_id() * tiles_per_block_},
          block_tile_idx_{0},
          finished_block_flags_{p_flags}
    {
    }

    /// @brief      True while this workgroup still has a valid tile to process.
    __device__ bool HasTile() const
    {
        return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
    }

    /// @brief      Advance to this workgroup's next tile.
    /// @return     True if the new tile is valid (more work remains for this workgroup).
    __device__ bool GetNextTile()
    {
        tile_id_++;
        block_tile_idx_++;
        return HasTile();
    }

    /// @brief      Number of synchronization flags required: one per workgroup.
    __device__ index_t GetFlagCount() const { return get_grid_size(); }

    ///
    /// @brief      Get this workgroup flag index.
    ///
    /// @return     The workgroup flag index.
    ///
    __device__ uint32_t GetWorkgroupFlagIdx() const { return static_cast<uint32_t>(blockIdx.x); }

    ///
    /// @brief      Flag each workgroup that has finished its work.
    ///
    __device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }

    ///
    /// @brief      Wait until each workgroup has finished its work.
    ///
    /// @param[in]  k_tiles     The number of tiles in the reduced dimension.
    /// @param[in]  k_tile_idx  The currently processed tile k index.
    ///
    /// @return     The number of neighbours.
    ///
    __device__ index_t WaitForNeighbours(index_t k_tiles, index_t k_tile_idx)
    {
        // We have to wait for all workgroups to finish their partial results.
        // First count how many "neighbour" workgroups we have to check.
        index_t neighbour_count = 0;
        if(tiles_per_block_ < k_tiles)
        {
            // Since we can have deviation (+/-1) in neighbours number
            // we calculate how many workgroups are needed to process the k-tiles left.
            neighbour_count = (k_tiles - k_tile_idx - 1 + tiles_per_block_ - 1) / tiles_per_block_;
        }
        // If we have more tiles to process than the reduction dimension size,
        // then the number of neighbours depends on first K-tile workgroup block tile idx.
        else
        {
            if(block_tile_idx_ == tiles_per_block_)
            {
                // If we just finished work per workgroup then check at which k-idx we are.
                neighbour_count = (k_tile_idx < k_tiles - 1) ? 1 : 0;
            }
            else
            {
                // If we have still tiles to process then it means that we already processed
                // whole K-dim.
                neighbour_count = 0;
            }
        }

        if(neighbour_count > 0)
        {
            // Spin until every neighbour has set its flag. NOTE(review): this reads flags at
            // GetWorkgroupFlagIdx() + 1 .. + neighbour_count, i.e. it assumes the workgroups
            // processing the remaining K-tiles have consecutive block ids directly after this
            // one — confirm against the block-to-C-tile map used by the caller.
            index_t flag_sum = 0;
            do
            {
                flag_sum = 0;
                for(index_t i = 1; i <= neighbour_count; ++i)
                {
                    flag_sum += finished_block_flags_.ld(GetWorkgroupFlagIdx() + i);
                }
            } while(flag_sum != neighbour_count);
        }

        return neighbour_count;
    }

    ///
    /// @brief      Wait until reduction workgroup has finished its work.
    ///
    __device__ void WaitForReduction()
    {
        // Wait until my counter has been reset.
        finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
    }

    ///
    /// @brief      Reset flag counter to zero.
    ///
    /// @param[in]  neighbour_count     The number of peer workgroups.
    ///
    __device__ void Reset(index_t neighbour_count)
    {
        // Clears this workgroup's own flag plus all neighbour flags
        // (inclusive upper bound is intentional).
        for(index_t i = 0; i <= neighbour_count; ++i)
        {
            finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
        }
    }

    ///
    /// @brief      Gets the flag value.
    ///
    __device__ uint32_t GetFlagValue() const
    {
        return finished_block_flags_.ld(GetWorkgroupFlagIdx());
    }

    const index_t tile_count_;      // Total number of tiles across the whole grid.
    const index_t tiles_per_block_; // ceil(tile_count_ / grid_size): tiles owned per workgroup.
    index_t tile_id_;               // Global index of this workgroup's current tile.
    index_t block_tile_idx_;        // How many of this workgroup's own tiles were consumed so far.
    workgroup_barrier finished_block_flags_; // Per-workgroup completion flags in global memory.
};

} // namespace ck