gemm.hpp 1.95 KB
Newer Older
Paul Fultz II's avatar
Paul Fultz II committed
1
// SPDX-License-Identifier: MIT
arai713's avatar
arai713 committed
2
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
Paul Fultz II's avatar
Paul Fultz II committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

#pragma once

#include <string>

namespace ck {
namespace host {
namespace operation {

/// Plain aggregate bundling the integer tile parameters of a single GEMM.
///
/// Every member is zero unless a value is supplied, so a default-constructed
/// TileDesc is all zeros. Member order is part of the interface because the
/// struct can be filled positionally via aggregate initialization.
///
/// NOTE(review): the meaning of each field is only suggested by its name here;
/// confirm against the code that consumes this descriptor.
struct TileDesc
{
    int block_size{0};

    // Per-block extents along m / n / k.
    int m_per_block{0};
    int n_per_block{0};
    int k_per_block{0};

    int ak1{0};
    int bk1{0};

    // XDL-related sizes.
    int m_per_XDL{0};
    int n_per_XDL{0};
    int m_Xdl_per_wave{0};
    int n_Xdl_per_wave{0};

    int num_gemmk_prefetch_stage{0};
};
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

/// Plain aggregate bundling the integer tile parameters of a two-GEMM
/// (gemm0 followed by gemm1) operation.
///
/// Every member is zero unless a value is supplied. Member order is part of
/// the interface because the struct can be filled positionally via aggregate
/// initialization.
///
/// NOTE(review): the meaning of each field is only suggested by its name here;
/// confirm against the code that consumes this descriptor.
struct TileDescGemmGemm
{
    int block_size{0};

    // Per-block extents; the m extent is named as shared by both GEMMs.
    int gemm01_m_per_block{0};
    int gemm0_n_per_block{0};
    int gemm0_k_per_block{0};
    int gemm1_n_per_block{0};
    int gemm1_k_per_block{0};

    int ak1{0};
    int bk1{0};
    int b1k1{0};

    // XDL-related sizes.
    int m_per_XDL{0};
    int n_per_XDL{0};
    int gemm0_m_Xdl_per_wave{0};
    int gemm0_n_Xdl_per_wave{0};
    int gemm1_n_Xdl_per_wave{0};

    int num_gemmk_prefetch_stage{0};
};

Paul Fultz II's avatar
Paul Fultz II committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/// Plain aggregate describing a block-transfer configuration.
///
/// The three string members default to the empty string and the integer
/// members default to 0. Member order is part of the interface because the
/// struct can be filled positionally via aggregate initialization.
///
/// NOTE(review): the expected format of the string members and the meaning of
/// the integer members are not established in this header; confirm against the
/// code that fills in and consumes this descriptor.
struct BlockTransferDesc
{
    // std::string default-constructs to empty, so no explicit "" initializer
    // is required.
    std::string thread_cluster_length;
    std::string thread_cluster_arrange_order;
    std::string src_access_order;
    int src_vec_dim              = 0;
    int src_scalar_per_vector    = 0;
    int dst_scalar_per_vector_k1 = 0;
    int lds_add_extra_dim        = 0;
};
/// Plain aggregate holding two integer "per shuffle" counts, one for m and
/// one for n. Both are zero unless a value is supplied.
///
/// NOTE(review): the exact meaning of the counts is only suggested by the
/// field names; confirm against the code that consumes this descriptor.
struct CShuffleDesc
{
    int m_Xdl_per_wave_per_shuffle{0};
    int n_Xdl_per_wave_per_shuffle{0};
};
/// Plain aggregate describing the C block-transfer configuration: one string
/// member (empty by default) followed by one integer member (0 by default).
/// Member order is part of the interface because the struct can be filled
/// positionally via aggregate initialization.
///
/// NOTE(review): the expected format of the string and the meaning of the
/// integer are not established in this header; confirm against the code that
/// fills in and consumes this descriptor.
struct CBlockTransferDesc
{
    // std::string default-constructs to empty, so no explicit "" initializer
    // is required.
    std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl;
    int scalar_per_vector_n_wave_n_per_Xdl = 0;
};

} // namespace operation
} // namespace host
} // namespace ck