gemm.hpp 7.23 KB
Newer Older
aska-0096's avatar
aska-0096 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/grid/grid_gemm_problem.hpp"
#include "ck/tile_program/grid/grid_gemm_v1.hpp"
#include "ck/tile_program/grid/grid_gemm_v1_default_policy.hpp"

// C = A * B
// Device-side functor: wraps the raw A/B/C pointers in 2-D global-memory tensor views whose
// logical index order is fixed (A: (M, K), B: (N, K), C: (M, N)) regardless of the storage
// layout tags, then hands the views and the element-wise functors to GridGemm.
//
// kAAlignment / kBAlignment / kCAlignment are forwarded to make_naive_tensor_view as
// Number<...>{} together with Number<1>{}.
// NOTE(review): these look like the guaranteed vector length / vector stride of the last
// dimension -- confirm against make_naive_tensor_view.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementFunction,
          typename BElementFunction,
          typename CElementFunction,
          ck::index_t kAAlignment,
          ck::index_t kBAlignment,
          ck::index_t kCAlignment,
          ck::index_t kBlockSize_,
          ck::index_t kMPerBlock_,
          ck::index_t kNPerBlock_,
          ck::index_t kKPerBlock_>
struct Gemm
{
    // Data types and element-wise functor types given to the grid-level GEMM.
    using GridGemmProblem = ck::tile_program::grid::GridGemmProblem<ADataType,
                                                                    BDataType,
                                                                    AccDataType,
                                                                    CDataType,
                                                                    AElementFunction,
                                                                    BElementFunction,
                                                                    CElementFunction>;

    // Policy given to the grid-level GEMM: block/tile sizes, the block-id -> tile-index map
    // and the block-level pipeline.
    struct GridGemmPolicy
    {
        // Re-export the enclosing template's sizes under the names used below.
        static constexpr ck::index_t kBlockSize = kBlockSize_;
        static constexpr ck::index_t kMPerBlock = kMPerBlock_;
        static constexpr ck::index_t kNPerBlock = kNPerBlock_;
        static constexpr ck::index_t kKPerBlock = kKPerBlock_;

        // Builds a callable that turns a linear block id into a 2-D index.
        //
        // The block id is decomposed over the lengths (NumTilesN, NumTilesM) and the two
        // components are returned in swapped order, i.e. the component bounded by NumTilesM
        // comes first. The Problem template parameter is not used in the body.
        // NOTE(review): presumably the result is (tile index along M, tile index along N) with
        // the M index varying fastest over consecutive block ids -- confirm against the merge
        // transform's index ordering.
        template <typename Problem>
        __host__ __device__ static constexpr auto MakeBlock2TileMap(ck::index_t NumTilesM,
                                                                    ck::index_t NumTilesN)
        {
            using namespace ck;

            // A merge transform over (N, M); its lower-index calculation performs the split of
            // the single merged (upper) index into the two lower ones.
            const auto block_id_splitter =
                make_merge_transform(make_tuple(NumTilesN, NumTilesM));

            return [block_id_splitter](index_t linear_block_id) {
                MultiIndex<2> idx_n_m;
                block_id_splitter.CalculateLowerIndex(idx_n_m,
                                                      make_multi_index(linear_block_id));

                // Swap the components so that the NumTilesM-bounded one is first.
                return make_multi_index(idx_n_m.At<1>(), idx_n_m.At<0>());
            };
        }

        // Returns a default-constructed block-level GEMM pipeline (the "V2" A-gmem/B-gmem/C-reg
        // pipeline with its default policy) for this struct's data types, block size and tile
        // shape. The Problem template parameter is not used in the body.
        template <typename Problem>
        __host__ __device__ static constexpr auto GetBlockGemmPipeline()
        {
            using namespace ck;
            using namespace ck::tile_program;
            using namespace ck::tile_program::block;

            using TileShape = TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>;

            using PipelineProblem =
                BlockGemmPipelineProblem<ADataType, BDataType, AccDataType, kBlockSize, TileShape>;

            using Pipeline =
                BlockGemmPipelineAGmemBGmemCRegV2<PipelineProblem,
                                                  BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>;

            return Pipeline{};
        }
    };

    using GridGemm = ck::GridGemmV1<GridGemmProblem, GridGemmPolicy>;

    // Builds the A, B and C tensor views and invokes GridGemm on them.
    //
    // p_a, p_b, p_c    : base pointers of A, B and C (viewed as global address space).
    // M, N, K          : problem sizes; A is viewed as (M, K), B as (N, K), C as (M, N).
    // Lda, Ldb, Ldc    : stride of the first dimension of each matrix *as stored*; the second
    //                    stored dimension always has stride 1.
    // *_element_func   : element-wise functors forwarded unchanged to GridGemm.
    __device__ void operator()(const ADataType* p_a,
                               const BDataType* p_b,
                               CDataType* p_c,
                               const ck::index_t M,
                               const ck::index_t N,
                               const ck::index_t K,
                               const ck::index_t Lda,
                               const ck::index_t Ldb,
                               const ck::index_t Ldc,
                               const AElementFunction& a_element_func,
                               const BElementFunction& b_element_func,
                               const CElementFunction& c_element_func) const
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        using RowMajorTag    = ck::tensor_layout::gemm::RowMajor;
        using ColumnMajorTag = ck::tensor_layout::gemm::ColumnMajor;

        // A as an (M, K) view.
        const auto a_view_m_k = [&] {
            constexpr bool kStoredAsMK = is_same_v<ALayout, RowMajorTag>;

            if constexpr(!kStoredAsMK)
            {
                // Stored as (K, M): describe it as stored, then swap the two dimensions.
                const auto a_view_k_m = make_naive_tensor_view<AddressSpaceEnum::Global>(
                    p_a, make_tuple(K, M), make_tuple(Lda, 1), Number<kAAlignment>{}, Number<1>{});

                return transform_tensor_view(
                    a_view_k_m,
                    make_tuple(make_pass_through_transform(M), make_pass_through_transform(K)),
                    make_tuple(Sequence<1>{}, Sequence<0>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
            else
            {
                // Stored as (M, K): the naive view already has the wanted index order.
                return make_naive_tensor_view<AddressSpaceEnum::Global>(
                    p_a, make_tuple(M, K), make_tuple(Lda, 1), Number<kAAlignment>{}, Number<1>{});
            }
        }();

        // B as an (N, K) view.
        const auto b_view_n_k = [&] {
            constexpr bool kStoredAsNK = is_same_v<BLayout, ColumnMajorTag>;

            if constexpr(!kStoredAsNK)
            {
                // Stored as (K, N): describe it as stored, then swap the two dimensions.
                const auto b_view_k_n = make_naive_tensor_view<AddressSpaceEnum::Global>(
                    p_b, make_tuple(K, N), make_tuple(Ldb, 1), Number<kBAlignment>{}, Number<1>{});

                return transform_tensor_view(
                    b_view_k_n,
                    make_tuple(make_pass_through_transform(N), make_pass_through_transform(K)),
                    make_tuple(Sequence<1>{}, Sequence<0>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
            else
            {
                // Stored as (N, K): the naive view already has the wanted index order.
                return make_naive_tensor_view<AddressSpaceEnum::Global>(
                    p_b, make_tuple(N, K), make_tuple(Ldb, 1), Number<kBAlignment>{}, Number<1>{});
            }
        }();

        // C as an (M, N) view.
        const auto c_view_m_n = [&] {
            constexpr bool kStoredAsMN = is_same_v<CLayout, RowMajorTag>;

            if constexpr(!kStoredAsMN)
            {
                // Stored as (N, M): describe it as stored, then swap the two dimensions.
                const auto c_view_n_m = make_naive_tensor_view<AddressSpaceEnum::Global>(
                    p_c, make_tuple(N, M), make_tuple(Ldc, 1), Number<kCAlignment>{}, Number<1>{});

                return transform_tensor_view(
                    c_view_n_m,
                    make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
                    make_tuple(Sequence<1>{}, Sequence<0>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
            else
            {
                // Stored as (M, N): the naive view already has the wanted index order.
                return make_naive_tensor_view<AddressSpaceEnum::Global>(
                    p_c, make_tuple(M, N), make_tuple(Ldc, 1), Number<kCAlignment>{}, Number<1>{});
            }
        }();

        GridGemm{}(a_view_m_k,
                   b_view_n_k,
                   c_view_m_n,
                   a_element_func,
                   b_element_func,
                   c_element_func);
    }
};