// prefetch.h
#pragma once
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "utils.h"
#include "static_switch.h"
#include "numeric_types.h"
#include "intrinsic_mls_ds.h"
/**
 * Streams a BLOCK_M x K tile from global memory into registers, one
 * BLOCK_K-wide slice per iteration, staging each slice through LDS.
 *
 * Per k_loop iteration the function:
 *   1. waits for all outstanding counters and synchronizes the block,
 *   2. issues K_LOAD_REQUESTS buffer loads per warp that write straight into
 *      LDS (padded layout, 32*34 elements per 32x32 sub-tile),
 *   3. waits and synchronizes again,
 *   4. reads the slice from LDS into k_reg via ds_read_tile_pad.
 *
 * Template parameters:
 *   K           total extent along the reduction/head dimension.
 *   BLOCK_M     rows handled by the thread block.
 *   BLOCK_K     columns handled per k_loop iteration.
 *   WARP_M      rows handled by one warp; BLOCK_M / WARP_M is the warp count.
 *   Element     storage element type.
 *   ElementAccum  not referenced in the visible body.
 *   Is_even_MN  when false, the row index is clamped to max_seq_k_offset - 1.
 *
 * Parameters:
 *   k_ptr             buffer resource passed to the buffer-load helper.
 *   k_lds             base of the LDS staging area.
 *   k_reg             destination register array filled by ds_read_tile_pad.
 *   max_seq_k_offset  number of valid rows (used only when !Is_even_MN).
 *   row_stride        distance between consecutive rows; it is halved when
 *                     forming the per-lane offset — presumably a conversion
 *                     from 2-byte elements to dwords, TODO confirm.
 *
 * Must be called by every thread of the block: the body contains
 * __syncthreads().
 */
template<int K, int BLOCK_M, int BLOCK_K, int WARP_M,  typename Element, typename ElementAccum, bool Is_even_MN>
inline __device__ void  prefetch_to_vgpr(
        vec4_uint k_ptr,
        Element* k_lds,
        union_vec2_f16x2<Element> k_reg[(K/BLOCK_K)*((WARP_M*BLOCK_K)/(32*32))*2][2],
        int max_seq_k_offset,
        int row_stride) {
    // Number of warps cooperating on the block tile.
    const int WARP_NUM = (BLOCK_M)/(WARP_M);
    // Total buffer-load requests per slice; one request covers a 4x32 patch.
    const int k_lds_load_num = (BLOCK_M * BLOCK_K) / (4*32);
    // Requests each warp has to issue per slice.
    const int K_LOAD_REQUESTS = k_lds_load_num / WARP_NUM;

    int warp_id =0;
    int warp_id_vec = threadIdx.x / 64; //warp id in a block

    // Take the value from the first active lane so warp_id is wave-uniform.
    warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);


    // NOTE(review): k_warp_m_id is not referenced again in the visible body.
    int k_warp_m_id         = (warp_id & ((BLOCK_M/WARP_M) - 1));
    int lane_id             = threadIdx.x & 63; //lane id, 0-63
    int k_lane_m_idx        = ((lane_id >> 4) & 1)*2 + ((lane_id >> 4) >> 1); //(0, 1, 2, 3) --> (0, 2, 1, 3)
    int k_lane_head_dim_idx = lane_id & 15;

    // int lds_offset = row * 8 + col * 32;
    // Only a single LDS stage is used here; stage_id is never changed.
    int stage_id = 0;
    
    // MLS
    // NOTE(review): k_srsrc is initialized but not passed to anything visible
    // in this function — confirm whether it is still needed.
    vec4_uint k_srsrc;
    k_srsrc[2] = row_stride;  // stride
    k_srsrc[3] = 0;

    #pragma unroll
    for(int k_loop = 0; k_loop<K/BLOCK_K; k_loop++) {
        {
            // Drain outstanding operations and make sure every warp is done
            // with the LDS contents of the previous slice before overwriting.
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_waitcnt(0);
            __syncthreads();
            __builtin_amdgcn_sched_barrier(0);
            //global->lds, left matrix
            int q_block_buffer_load_global_offset = k_loop * BLOCK_K ;//+ block_id_m * BLOCK_M * K;
            // k_ptr buffer load mini size is 4*32, (BLOCK_M * BLOCK_K) mini size is (32*32)
            int k_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*34);
            // Requests are distributed round-robin over the warps.
            for(int load = 0,warp_loop = warp_id; load < K_LOAD_REQUESTS; warp_loop += WARP_NUM, ++load) {
                int padding = (warp_loop & 7)*2; // padding size in shared memory per buffer load, to avoid bank conflict
                int k_warp_buffer_load_m_id = (warp_loop & (BLOCK_M/4 - 1)); // open question from the original author: how does this access pattern affect L1 and UTLC1?
                    // int q_warp_buffer_load_k_id = (warp_loop / (BLOCK_M/4));
                // LDS destination: 32*34 elements per group of eight requests,
                // 4*32 elements per request inside the group.
                int q_warp_buffer_load_lds_offset     =  k_lds_stage_offset/* + (q_warp_buffer_load_k_id * BLOCK_M * 34)*/ + ((k_warp_buffer_load_m_id >> 3)*(32*34) + (k_warp_buffer_load_m_id & 7)*(4*32));
                // int q_warp_buffer_load_global_offset  =  (q_warp_buffer_load_k_id * 32);

                // Wave-uniform part of the global offset (column of the slice).
                int gvOffset_s = (q_block_buffer_load_global_offset/* + q_warp_buffer_load_global_offset*/) / 2;
                // Per-lane part of the global offset (row and column in the patch).
                int gvOffset_v;
                if constexpr (not Is_even_MN) {
                    // Clamp the row so out-of-range lanes re-read the last valid row.
                    gvOffset_v = ((min(k_warp_buffer_load_m_id * 4 + k_lane_m_idx, max_seq_k_offset - 1)) * row_stride) / 2 + k_lane_head_dim_idx;
                } else {
                    gvOffset_v = ((k_warp_buffer_load_m_id * 4 + k_lane_m_idx) * row_stride) / 2 + k_lane_head_dim_idx;
                }
                int lds_offset = (q_warp_buffer_load_lds_offset + padding) / 2; // +  lane_id;

                builtin_buffer_load_dword_lds_bypass_glc_slc(k_lds, k_ptr, lds_offset, gvOffset_s, gvOffset_v);
            }

            // Wait for the loads issued above and publish the LDS contents to
            // the whole block before reading them back.
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_waitcnt(0);
            __syncthreads();
            __builtin_amdgcn_sched_barrier(0);
                    
            // k_lds_stage_offset = stage_id * (BLOCK_M/32) * (BLOCK_K/32)*(32*17);

            // LDS -> registers for slice k_loop.
            vec2_Element<Element> *k_lds_v2fp16 = (vec2_Element<Element> *)(k_lds);
            ds_read_tile_pad(WARP_M, BLOCK_K, WARP_NUM, Element, k_lds_v2fp16, k_lds_stage_offset, k_reg, k_loop, warp_id, lane_id);
        }
    }
}

// matrix_load unit: 32 * 32
// ds_read_matrix unit: 32 * 16
// M = 128, N = 128
/**
 * Loads a tile from global memory into registers through LDS using 32x32
 * matrix-load instructions and a two-stage (double-buffered) LDS pipeline.
 *
 * LDS layout: every warp owns one 32x32-element slot per stage. The slot of
 * warp `w` in stage `s` starts at element offset
 * `w * 32 * 32 + s * WARP_NUM * 32 * 32` from `lds`.
 *
 * Pipeline: a prologue issues the load of column tile 0 into stage 0. Each
 * loop iteration then flips the stage, issues the load of column tile
 * `n_loop` (if one is left) and reads column tile `n_loop - 1` from the other
 * stage into `reg[(n_loop - 1) * 2]` and `reg[(n_loop - 1) * 2 + 1]` (two
 * 32x16 halves per 32x32 tile).
 *
 * Template parameters:
 *   trans         selects the transposing variants of the load/read helpers.
 *   M, N          tile extents; the original author documents M = 128, N = 128.
 *   Element       storage element type; half_t selects the f16 builtins,
 *                 anything else selects the bf16 builtins.
 *   ElementAccum  not referenced in the body.
 *   Is_even_MN    when false, rows beyond `max_column_offset` are filtered via
 *                 the resource descriptor (srsrc[3]).
 *
 * Parameters:
 *   ptr                buffer resource; ptr[2] is used as the row stride (in
 *                      elements) and the low 64 bits as the base address.
 *   lds                base of the LDS staging area.
 *   reg                destination registers.
 *   max_column_offset  number of valid rows (used only when !Is_even_MN).
 *   warp_id            index of the calling warp; selects the 32-row band.
 *
 * NOTE(review): neither the global offset nor the destination register index
 * depends on m_loop, and the prologue is issued only once, so the function
 * appears to be correct only for M == 128 — confirm before using larger M.
 */
template<bool trans, int M, int N,  typename Element, typename ElementAccum, bool Is_even_MN>
inline __device__ void  prefetch_to_vgpr_gfx938(
        vec4_uint ptr,
        Element* lds,
        union_vec4_f16x2<Element> reg[M * N / (64 * 8)], // a vec4_fp16x2 holds 8 elements; 64 threads per warp
        int max_column_offset,
        int warp_id) {
    constexpr int ELEMENT_BYTES   = sizeof(Element);
    const int stages = 2;
    const int WARP_NUM = 4;
    int row_stride = ptr[2];
    vec4_uint srsrc;
    srsrc[2] = row_stride;
    srsrc[3] = 0;

    // LDS slot of this warp (in elements): each warp uses one 32x32 tile.
    int lds_offset = (warp_id * 32 * 32);

    int stages_id = 0;
    if(stages == 2) {
        // Prologue: start the load of column tile 0 into stage 0.
        int m_loop = 0;
        int n_loop = 0;
        int global_offset = (warp_id * row_stride * 32 + n_loop * 32);
        int lds_offset_stage = (lds_offset + stages_id * (WARP_NUM * 32 * 32)) * ELEMENT_BYTES;
        if constexpr (!Is_even_MN) {
            // Boundary check along M: how many rows of this warp's 32-row band
            // lie past the valid range and therefore have to be zero-padded.
            int nm_filter_max = (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset;
            int nm_filter = max(0, nm_filter_max);
            if(nm_filter_max >= 32) {
                // The whole band is out of range: redirect the load to band 0.
                global_offset = (0 * row_stride * 32 + n_loop * 32);
                nm_filter = max(0, (m_loop * 128 + 0 * 32) - max_column_offset);
            }
            srsrc[3] = nm_filter << 8; // set only once
        }
        *(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
        if constexpr (trans) {
            inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset_stage, 0);
        } else {
            inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset_stage, 0);
        }
    }
    for(int m_loop = 0; m_loop < M / 128; ++m_loop) {
        for(int n_loop = stages - 1; n_loop < N / 32 + stages - 1; ++n_loop) {
            if(stages == 2) {
                stages_id ^= 1;
            }
            // Update the global address for column tile n_loop.
            int global_offset = (warp_id * row_stride * 32 + n_loop * 32);
            int lds_offset_stage = (lds_offset + stages_id * (WARP_NUM * 32 * 32)) * ELEMENT_BYTES;
            if constexpr (!Is_even_MN) {
                // Boundary check along M, same rule as in the prologue.
                int nm_filter_max = (m_loop * 128 + (warp_id + 1) * 32) - max_column_offset;
                int nm_filter = max(0, nm_filter_max);
                if(nm_filter_max >= 32) {
                    global_offset = (0 * row_stride * 32 + n_loop * 32);
                    nm_filter = max(0, (m_loop * 128 + 0 * 32) - max_column_offset);
                }
                srsrc[3] = nm_filter << 8; // set only once
            }
            *(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset * ELEMENT_BYTES);
            // The final iteration only drains the pipeline; nothing is left to load.
            if(n_loop < N / 32) {
                if constexpr (trans) {
                    inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset_stage, 0);
                } else {
                    inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset_stage, 0);
                }
            }

            // With a load still in flight for the next tile, allow one
            // outstanding request; otherwise wait for everything.
            // (presumably the argument is the permitted vmcnt — TODO confirm)
            if(stages == 2 && n_loop < N /32) {
                vmcnt_wait_nosync(1);
            } else {
                vmcnt_wait_nosync(0);
            }

            // Tile to consume: with two stages it is the one loaded in the
            // previous iteration, living in the other LDS stage.
            const int reg_idx = (stages == 2 ? (n_loop - 1) : n_loop) * 2;
            Element* read_ptr = lds + (stages == 2 ? (stages_id ^ 1) : stages_id) * (WARP_NUM * 32 * 32) + lds_offset;

            // Two reads per 32x32 tile; the second builtin argument (0 / 1024)
            // presumably is the byte offset of the second 32x16 half — TODO confirm.
            if constexpr (trans) {
                if constexpr (std::is_same_v<Element, half_t>) {
                    reg[reg_idx].f16x8     = __builtin_hcu_ds_read_matrix_trans_format_f16(read_ptr, 0, 2, 1, 0);
                    reg[reg_idx + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(read_ptr, 1024, 2, 1, 0);
                } else {
                    reg[reg_idx].f16x8     = __builtin_hcu_ds_read_matrix_trans_format_bf16(read_ptr, 0, 2, 1, 0);
                    reg[reg_idx + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(read_ptr, 1024, 2, 1, 0);
                }
            } else {
                if constexpr (std::is_same_v<Element, half_t>) {
                    reg[reg_idx].f16x8     = __builtin_hcu_ds_read_matrix_format_f16(read_ptr, 0, 2, 1, 0);
                    reg[reg_idx + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(read_ptr, 1024, 2, 1, 0);
                } else {
                    reg[reg_idx].f16x8     = __builtin_hcu_ds_read_matrix_format_bf16(read_ptr, 0, 2, 1, 0);
                    reg[reg_idx + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_bf16(read_ptr, 1024, 2, 1, 0);
                }
            }
            // Make sure the LDS reads have landed before the stage is reused.
            lgkmcnt_wait(0);
        }
    }
}

// matrix_load unit: 32 * 32
// ds_read_matrix unit: 32 * 16
// M = 32, N = 128
/**
 * Issues 32x32 matrix loads that copy an M x N tile from global memory into
 * LDS. The function only issues the loads; it does not wait for them.
 *
 * The M x N tile is flattened into LOAD_NUM = M * N / (32 * 32) sub-tiles,
 * which are distributed round-robin over WARP_NUM warps. Sub-tile `t` is
 * written to LDS at byte offset `t * 32 * 32 * sizeof(Element)`.
 *
 * Template parameters:
 *   trans         selects the transposing matrix-load helper.
 *   M, N          tile extents; the original author documents M = 32, N = 128.
 *   Element       storage element type.
 *   ElementAccum  not referenced in the body.
 *   Is_even_MN    when false, rows beyond `max_column_offset` are filtered via
 *                 the resource descriptor (srsrc[3]).
 *   WARP_NUM      number of warps sharing the work (default 4).
 *
 * Parameters:
 *   ptr                  buffer resource; ptr[2] is used as the row stride and
 *                        the low 64 bits as the base address.
 *   global_start_offset  start offset into the source, in elements.
 *   lds                  base of the LDS destination.
 *   max_column_offset    number of valid rows (used only when !Is_even_MN).
 *   warp_id              index of the calling warp.
 *
 * NOTE(review): the row term of the global offset is `m_loop * row_stride`,
 * while the boundary check treats m_loop as an index of 32-row tiles. The two
 * agree for M == 32 (m_loop is always 0); confirm before using larger M.
 */
template<bool trans, int M, int N,  typename Element, typename ElementAccum, bool Is_even_MN, int WARP_NUM = 4>
inline __device__ void  prefetch_to_lds_gfx938(
        vec4_uint ptr,
        int global_start_offset,
        Element* lds,
        int max_column_offset,
        int warp_id) {
    const int ELEMENT_BYTES   = sizeof(Element);
    const int LOAD_NUM = M * N / (32 * 32);
    int row_stride = ptr[2];
    vec4_uint srsrc;
    srsrc[2] = row_stride;
    srsrc[3] = 0;
    // Walk the flattened M * N tile and pick the 32x32 sub-tiles that belong
    // to this warp.
    for(int loop = 0; loop < (LOAD_NUM + WARP_NUM - 1) / WARP_NUM; loop++) {
        int loop_warp = loop * WARP_NUM + warp_id;
        if (loop_warp < LOAD_NUM) {
            int m_loop = loop_warp / (N / 32);
            int n_loop = loop_warp % (N / 32);
            // Update the global address (in bytes) for this sub-tile.
            int global_offset = (global_start_offset + m_loop * row_stride + n_loop * 32) * ELEMENT_BYTES;
            if constexpr (!Is_even_MN) {
                // Boundary check along M: how many rows of this sub-tile lie
                // past the valid range and therefore have to be zero-padded.
                int nm_filter_max = (m_loop + 1) * 32 - max_column_offset;
                int nm_filter = nm_filter_max;
                if(nm_filter_max >= 32) {
                    // The whole sub-tile is out of range: redirect to row-tile 0.
                    global_offset = (global_start_offset + 0 * row_stride + n_loop * 32) * ELEMENT_BYTES;
                    nm_filter = (0 + 1) * 32 - max_column_offset;
                }
                nm_filter = max(0, nm_filter);
                srsrc[3] = nm_filter << 8; // set only once
            }
            *(uint64_t*)&srsrc = VA_LIMIT_BITS(*(uint64_t*)&ptr + global_offset);
            // LDS destination in bytes: one 32x32 slot per sub-tile index.
            // NOTE(review): the original comment said the slot is reused by the
            // next loop iteration, but the offset is derived from loop_warp, so
            // every sub-tile gets its own slot.
            int lds_offset = (loop_warp * 32 * 32) * ELEMENT_BYTES;
            if constexpr (trans) {
                inline_matrix_load_32x32_b16_lds_trans<0, 0>(lds, srsrc, lds_offset, 0);
            } else {
                inline_matrix_load_32x32_b16_lds<0, 0>(lds, srsrc, lds_offset, 0);
            }
        }
    }
}

/**
 * Copies the B operand into a temporary LDS area tile by tile, throttling
 * each step on the LGKM counter.
 *
 * The function iterates over BLOCK_N / WARP_N row groups and K / BLOCK_K
 * column slices. Before each tile it calls lgkmcnt_wait with a value that
 * starts at (number of tiles - 1) and decreases by one per tile, then issues
 * the load through buffer_load_lds_tile_pad.
 *
 * Template parameters:
 *   Is_even_MN  forwarded to buffer_load_lds_tile_pad.
 *   K           head dimension.
 *   BLOCK_M, WARP_M  BLOCK_M / WARP_M gives the warp count.
 *   BLOCK_N, WARP_N  BLOCK_N / WARP_N gives the number of row groups.
 *   BLOCK_K     columns per slice.
 *   Element     storage element type.
 *
 * Parameters:
 *   B_ptr             buffer resource of the source.
 *   B_lds             base of the LDS destination.
 *   max_n_len_offset  forwarded to buffer_load_lds_tile_pad.
 *   warp_id           index of the calling warp.
 *   row_stride        forwarded to buffer_load_lds_tile_pad.
 */
template<bool Is_even_MN, int K/*head_dim*/, int BLOCK_M, int BLOCK_N, int BLOCK_K, int WARP_M, int WARP_N, typename Element>
__forceinline__ __device__ void  prefetch_to_tmp_lds_wait(vec4_uint B_ptr, Element* B_lds, int max_n_len_offset, int warp_id, int row_stride)
{
    const int WARP_NUM = BLOCK_M/WARP_M;
    int lane_id = threadIdx.x & 63; //lane id, 0-63
    for(int n_loop = 0 ; n_loop < BLOCK_N/WARP_N; n_loop++){
        for(int k_loop = 0; k_loop < K/BLOCK_K; k_loop++) {
            // Counter value to wait for: tiles remaining after this one.
            const int lgkmcnt = (BLOCK_N/WARP_N * K/BLOCK_K - 1) - (n_loop * K/BLOCK_K + k_loop);
            lgkmcnt_wait(lgkmcnt);
            int B_block_buffer_load_global_offset = k_loop * BLOCK_K + n_loop * WARP_N * K;
            // With headdim=256 the LDS usage is 256/32 * 32 * 34 * 2 bytes = 17 KB;
            // reading both Q and dO into LDS at the same time would exceed 32 KB.
            // With headdim=224 the LDS usage is 224/32 * 32 * 34 * 2 bytes = 14.875 KB;
            // reading both Q and dO into LDS at the same time stays below 32 KB.
            int B_lds_stage_offset = k_loop * (WARP_N/32) * (BLOCK_K/32)*(32*34) + n_loop * (K/32) * (WARP_N/32)*(32*34);
            buffer_load_lds_tile_pad(Is_even_MN, WARP_NUM, row_stride, WARP_N, BLOCK_K, Element, B_ptr, B_lds, B_block_buffer_load_global_offset, B_lds_stage_offset, max_n_len_offset, warp_id, lane_id);
        }
    }
}