// gemm.cpp
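//
// Design sketch of a "multi-D" GEMM, E = cde_op(A * B, D0, D1, ...), written
// against a tile-programming style API. The tile primitives used below
// (TileProgram, make_naive_tensor, make_window, make_block_gemm, load, copy,
// dot_product_accumulate) and the *_strategy objects are assumed to be
// provided by the surrounding library headers; only <array> is a standard
// dependency of this snippet.

#include <array>
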
template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct GemmMultiD
{
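    // number of extra D tensors fused into the epilogue:
    // E = cde_op(A * B, D0, ..., D_{NumDTensor-1})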
    static constexpr index_t NumDTensor = DsDataType::Size();

    __host__ __device__ void
    operator()(TileProgram& tp,
               const std::array<index_t, 2> a_m_k_lengths,
               const std::array<index_t, 2> a_m_k_strides,
               const std::array<index_t, 2> b_n_k_lengths,
               const std::array<index_t, 2> b_n_k_strides,
               const std::array<std::array<index_t, 2>, NumDTensor> ds_m_n_lengths,
               const std::array<std::array<index_t, 2>, NumDTensor> ds_m_n_strides,
               const std::array<index_t, 2> e_m_n_lengths,
               const std::array<index_t, 2> e_m_n_strides,
               //
               const ADataType* p_a,
               const BDataType* p_b,
               const std::array<const void*, NumDTensor> p_ds,
               EDataType* p_e)
    {
        using namespace ck;

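        // register the problem-description tensors (global-memory views) with
        // the tile program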
        const auto a  = tp(make_naive_tensor(a_m_k_lengths, a_m_k_strides), p_a);
        const auto b  = tp(make_naive_tensor(b_n_k_lengths, b_n_k_strides), p_b);
        const auto ds = tp(generate_tuple(
            [&](auto i) {
                return make_naive_tensor(ds_m_n_lengths[i], ds_m_n_strides[i], p_ds[i]);
            },
            Number<NumDTensor>{}));
        auto e        = tp(make_naive_tensor(e_m_n_lengths, e_m_n_strides), p_e);

        // divide problem
        const auto num_m = e_m_n_lengths[0];
        const auto num_n = e_m_n_lengths[1];

        const auto id_block = get_block_1d_id();

        const auto num_tile_m = num_m / kMPerBlock;
        const auto num_tile_n = num_n / kNPerBlock;

        const auto block2tile = tp(make_cluster_descriptor(make_tuple(num_tile_m, num_tile_n)));

        const auto id_tile = block2tile.CalculateBottomIndex(id_block);

        const auto id_tile_m = id_tile.At<0>();
        const auto id_tile_n = id_tile.At<1>();

        // A/B in DRAM
        // the A/B DRAM layout is part of the problem, not the solution
#if 1
        // do NOT expose to the user that tensor-transform optimizations may be
        // applied to the A/B DRAM tensors
        const auto a_dram_global = tp(make_naive_tensor(a_m_k_lengths, a_m_k_strides), p_a);
        const auto b_dram_global = tp(make_naive_tensor(b_n_k_lengths, b_n_k_strides), p_b);
#endif

        // A/B tile in LDS
        // the A/B LDS layout is part of the solution
        ADataType* p_a_lds = shared_memory.get_pointer(0);

        // [allow optimization] allow different LDS layouts
        auto a_lds_block =
            make_tensor(p_a_lds, {kMPerBlock, kKPerBlock}, a_lds_block_strategy);

        const auto a_lds_byte = a_lds_block.get_num_of_byte();

        BDataType* p_b_lds = shared_memory.get_aligned_pointer(a_lds_byte);

        // [allow optimization] allow different LDS layouts
        auto b_lds_block =
            make_tensor(p_b_lds, {kNPerBlock, kKPerBlock}, b_lds_block_strategy);

        // A/B copy
#if 0
        auto a_block_copy = make_copier(a_dram_global,
                                        a_lds_block,
                                        make_tuple(kMPerBlock, kKPerBlock),
                                        make_tuple(id_tile_m * kMPerBlock, 0),
                                        a_block_copy_strategy);

        auto b_block_copy = make_copier(b_dram_global,
                                        b_lds_block,
                                        make_tuple(kNPerBlock, kKPerBlock),
                                        make_tuple(id_tile_n * kNPerBlock, 0),
                                        b_block_copy_strategy);
#else
        auto window_a_dram = make_window(a_dram_global,
                                         {kMPerBlock, kKPerBlock},
                                         {id_tile_m * kMPerBlock, 0},
                                         a_dram_window_map_strategy);

        auto window_a_block =
            make_window(a_lds_block, {kMPerBlock, kKPerBlock}, {0, 0}, a_lds_window_map_strategy);
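
        // B-side DRAM/LDS windows, mirroring the A-side windows above
        // (assumed: the original sketch only spells these out for A, but
        // window_b_dram is consumed in the main loop below)
        auto window_b_dram = make_window(b_dram_global,
                                         {kNPerBlock, kKPerBlock},
                                         {id_tile_n * kNPerBlock, 0},
                                         b_dram_window_map_strategy);

        auto window_b_block =
            make_window(b_lds_block, {kNPerBlock, kKPerBlock}, {0, 0}, b_lds_window_map_strategy);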

#endif

#if 1
        // block GEMM
        // operation-based syntax: per-operation solution strategy
        auto block_gemm = make_block_gemm(a_lds_block, b_lds_block, block_gemm_strategy);
#endif

        // Distributed C in VGPR
#if 1
        // C layout is decided independently of the block GEMM
        // C is distributed across the threads of the block
        auto c_vgpr_block =
            make_distributed_tensor({kMPerBlock, kNPerBlock}, c_vgpr_block_strategy);
#elif 0
        // C layout is decided by block GEMM
        auto c_vgpr_block = block_gemm.get_c_vgpr_block();
#endif

        // main loop over K: stream A/B tiles from DRAM through LDS and
        // accumulate into the distributed C tile
        const auto K = a_m_k_lengths[1];

        for(index_t k = 0; k < K; k += kKPerBlock)
        {
            auto a_vgpr_block_tmp = load(window_a_dram, a_dram_load_strategy);
            auto b_vgpr_block_tmp = load(window_b_dram, b_dram_load_strategy);

            auto a_vgpr_block = elementwise_op(a_vgpr_block_tmp, a_element_op);
            auto b_vgpr_block = elementwise_op(b_vgpr_block_tmp, b_element_op);

            copy(a_vgpr_block, a_lds_block, a_lds_store_strategy);
            copy(b_vgpr_block, b_lds_block, b_lds_store_strategy);

            block_sync_lds();

            dot_product_accumulate(c_vgpr_block, a_lds_block, b_lds_block);

            block_sync_lds();

            window_a_dram += {0, kKPerBlock};
            window_b_dram += {0, kKPerBlock};
        }
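
        // epilogue: shuffle the distributed C tile through LDS, apply the CDE
        // element-wise op with the D tensors, and store the E tile to DRAM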

        auto p_c_lds = xxx;

        auto c_lds = make_tensor(p_c_lds, xxxxxx);
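        // p_c_lds and the c_lds descriptor above are left unspecified in this
        // sketch; a natural choice (an assumption, not stated in the original)
        // is to reuse the LDS that held the A/B tiles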

        auto window_c_vgpr =
            make_window(c_vgpr_block, {kMPerShuffle, kNPerShuffle}, {0, 0}, c_vgpr_window_strategy);

        auto window_c_lds =
            make_window(c_lds, {kMPerShuffle, kNPerShuffle}, {0, 0}, c_lds_window_strategy);

        // a single representative D tensor is shown here for brevity; with
        // NumDTensor > 1 there would be one such window per D
        auto window_d_dram = make_window(d_dram_global,
                                         {kMPerShuffle, kNPerShuffle},
                                         {id_tile_m * kMPerBlock, id_tile_n * kNPerBlock},
                                         d_dram_window_strategy);

        auto window_e_dram = make_window(e_dram_global,
                                         {kMPerShuffle, kNPerShuffle},
                                         {id_tile_m * kMPerBlock, id_tile_n * kNPerBlock},
                                         e_dram_window_strategy);

        for(index_t m = 0; m < kMPerBlock; m += kMPerShuffle)
        {
            for(index_t n = 0; n < kNPerBlock; n += kNPerShuffle)
            {
                // write C into LDS for shuffle
                copy(window_c_vgpr, window_c_lds, c_lds_store_strategy);

                // load C from LDS to complete shuffle
                auto c_vgpr_slice_shuffled = load(window_c_lds, c_lds_load_strategy);

                // load D from dram
                auto d_vgpr_block_slice = load(window_d_dram, d_dram_load_strategy);

                // element-wise op
                // [Question] need to guarantee it always functions:
                //   1. C/D should have the same layout; how to guarantee that?
                //   2. if C/D have different layouts, then a shuffle is needed
                //   3. if C/D have different layouts, what should the E layout be?
                auto e_vgpr_block_slice =
                    elementwise_op(c_vgpr_slice_shuffled, d_vgpr_block_slice, cd_elementwise_op);

                // write E into dram
                copy(e_vgpr_block_slice, window_e_dram, e_dram_store_strategy);
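
                // NOTE: the moves that advance window_c_vgpr, window_c_lds,
                // window_d_dram and window_e_dram to the next (m, n) shuffle
                // slice are omitted in this sketch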
            }
        }
    }
};
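
// Example instantiation (illustrative only; the data-type list, the Tuple used
// for DsDataType, and the element-wise functor names below are assumptions,
// not part of the sketch above):
//
//   using PassThrough = /* identity element-wise functor */;
//   using AddD        = /* cde_op(c, d) = c + d */;
//
//   using Gemm = GemmMultiD<ck::half_t,            // ADataType
//                           ck::half_t,            // BDataType
//                           ck::Tuple<ck::half_t>, // DsDataType: one D tensor
//                           ck::half_t,            // EDataType
//                           PassThrough,           // AElementwiseOperation
//                           PassThrough,           // BElementwiseOperation
//                           AddD>;                 // CDEElementwiseOperation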