// blkgemmpipe_scheduler.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

namespace ck {

// Scheduling-policy selector for blockwise GEMM pipelines.
// NOTE(review): judging only by the names, Intrawave presumably interleaves work
// within a single wave and Interwave across waves -- confirm in the pipelines
// that consume this enum.
enum struct BlockGemmPipelineScheduler
{
    Intrawave, // implicit underlying value 0
    Interwave  // implicit underlying value 1
};

// Tail-case tag for blockwise GEMM pipelines. The enumerators fall into three
// families, one per pipeline flavour, grouped below. Declaration order fixes the
// implicit underlying values (Odd == 0 ... Full == 10).
// NOTE(review): the precise meaning of each tail case is defined by the pipeline
// implementations, not by this header -- confirm there before relying on it.
enum struct TailNumber
{
    // ---- Single- and double-buffer pipelines ----
    Odd,
    Even,

    // ---- Long-prefetch pipelines (up to 8) ----
    One,
    Two,
    Three,
    Four,
    Five,
    Six,
    Seven,

    // ---- Unrolled pipelines ----
    // Empty: there are more unroll stages than prefetch stages, and the loop
    // count is a multiple of the number of unroll stages.
    Empty,
    // Full: unroll stages <= prefetch stages, and the loop count is a multiple
    // of the number of unroll stages, plus the number of prefetch stages.
    Full
};
/// Compile-time instruction-count model for one hot-loop iteration of an
/// XDL-based blockwise GEMM pipeline.
///
/// Every count below is plain integer arithmetic on the template parameters, so
/// results truncate when a tile does not divide evenly. The formulas treat the
/// block as WaveNumM x WaveNumN waves of WaveSize lanes each.
/// NOTE(review): that layout implies BlockSize == WaveNumM * WaveNumN * WaveSize;
/// nothing here checks it -- confirm when adding new tile configurations.
template <index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t ABufferLoadWidth,
          index_t BBufferLoadWidth,
          index_t ALDSWriteWidth,
          index_t BLDSWriteWidth,
          index_t ALDSReadWidth,
          index_t BLDSReadWidth,
          index_t MRepeat,
          index_t NRepeat,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t KPerXDL>
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
{
    // Lanes per wave; hard-coded here rather than taken as a parameter.
    static constexpr index_t WaveSize = 64;
    // Number of waves along M / N: each wave covers MRepeat * MPerXDL rows of M
    // and NRepeat * NPerXDL columns of N.
    static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);

    static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
    static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;

    // Per-thread buffer-load count: the A (MPerBlock x KPerBlock) or
    // B (NPerBlock x KPerBlock) tile is split across all BlockSize threads, each
    // load moving *BufferLoadWidth elements.
    static constexpr index_t A_Buffer_Load_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
    static constexpr index_t B_Buffer_Load_Inst_Num =
        NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);

    // Per-thread LDS-write count: same split as the buffer loads, but with the
    // LDS write width.
    static constexpr index_t A_LDS_Write_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
    static constexpr index_t B_LDS_Write_Inst_Num =
        NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);

    // Per-thread LDS-read count. Every wave column (WaveNumN of them) reads the
    // whole A tile, so the A tile is counted WaveNumN times before being spread
    // over the block's threads. Symmetrically, every wave row (WaveNumM of them)
    // reads the whole B tile.
    static constexpr index_t A_LDS_Read_Inst_Num =
        WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
    // Must scale by the B tile's own extent (NPerBlock). Using MPerBlock here is
    // only correct for square M/N tiles.
    static constexpr index_t B_LDS_Read_Inst_Num =
        WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);

    // MFMA count per wave: the full M x N x K block product, divided by the
    // number of waves (BlockSize / WaveSize), divided by the M x N x K covered by
    // one XDL instruction.
    static constexpr index_t C_MFMA_Inst_Num =
        MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);

    // Debug aid: dumps the tile configuration and every derived count via printf.
    static constexpr auto Print()
    {
        printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
               BlockSize,
               WaveSize,
               MPerBlock,
               NPerBlock,
               KPerBlock,
               MPerXDL,
               NPerXDL,
               KPerXDL);

        printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
               "%d, %d\n C MFMA inst: %d\n"
               "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
               "%d/ %d\n",
               A_Buffer_Load_Inst_Num,
               B_Buffer_Load_Inst_Num,
               A_LDS_Write_Inst_Num,
               B_LDS_Write_Inst_Num,
               A_LDS_Read_Inst_Num,
               B_LDS_Read_Inst_Num,
               C_MFMA_Inst_Num,
               A_LDS_Read_Width,
               B_LDS_Read_Width,
               ALDSWriteWidth,
               BLDSWriteWidth,
               ABufferLoadWidth,
               BBufferLoadWidth);
    }
};

} // namespace ck