gemm_awq.cu 70.6 KB
Newer Older
fengzch-das's avatar
fengzch-das committed
1
2
#include <cuda_fp16.h>
#include <cuda_bf16.h>
3
#include "semaphore.h"
muyangli's avatar
muyangli committed
4
#include "gemm_awq.h"
Muyang Li's avatar
Muyang Li committed
5
// #include "../../../nunchaku/csrc/quantization/dequantize.cuh"
Samuel Tesfai's avatar
Samuel Tesfai committed
6
7
8
#include "dequantize.cuh"
#include <stdio.h>
#include "../dispatch_utils.h"
Muyang Li's avatar
Muyang Li committed
9
// #include "../../../nunchaku/csrc/utils.cuh"
Samuel Tesfai's avatar
Samuel Tesfai committed
10
11
#include "../utils.cuh"

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <cuda_pipeline_primitives.h>

// Number of B (weight) rows interleaved into one shared-memory row by the B loaders.
#define kInterleave 4
// Shape of one mma.sync.m16n8k16 instruction.
#define OP_M 16
#define OP_N 8
#define OP_K 16
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 16
#define WARP_SIZE 32
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
// Elements moved per 16-byte (uint4) vector access.
#define PACK_SIZE 8
// L2_CACHEHINT(size) expands to the ".L2::<size>B" prefetch-size qualifier for cp.async, which needs
// PTX ISA 7.4 (CUDA 11.4) or newer. The (major, minor) pair is compared lexicographically: testing the
// minor version on its own would wrongly disable the hint for CUDA 12.0-12.3.
#if defined(__CUDACC_VER_MAJOR__) &&                                                                                   \
    ((__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif

Muyang Li's avatar
Muyang Li committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// KERNEL_LAUNCH_CODE: host-side launch sequence for gemm_w4a16_T1, expanded inside a dispatch function.
// The expanding scope must provide:
//   - the compile-time tile parameters CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK
//     and the element type f16_t;
//   - the kernel arguments in_feats, kernel, scales, zeros, out_feats;
//   - the ints num_in_feats, num_out_feats, num_in_channels, num_out_channels;
//   - the tensors _in_feats and _out_feats.
// On a shared-memory, attribute or launch failure it prints a message and returns _out_feats from the
// enclosing function.
// NOTE(review): the grid uses num_out_channels / CTA_N (floor) while the kernel counts N tiles with a
// ceiling division, so num_out_channels is presumably a multiple of CTA_N -- confirm.
// NOTE(review): SCALES_SMEM_SIZE already includes STAGES and the total is multiplied by STAGES again, so
// the scales/zeros region looks over-allocated. The kernel also reuses this buffer as float C_shared, so
// confirm that reduction's footprint before tightening the size.
#define KERNEL_LAUNCH_CODE                                                                                             \
    /* One semaphore per (M tile, N tile) pair; parenthesized so the count is exactly m_tiles * n_tiles. */            \
    int num_mn_tiles        = ((num_in_feats + CTA_M - 1) / CTA_M) * ((num_out_channels + CTA_N - 1) / CTA_N);         \
    Tensor _semaphores      = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device());                        \
    auto semaphores         = reinterpret_cast<int *>(_semaphores.data_ptr<int>());                                    \
    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K);                                  \
    constexpr int SCALES_SMEM_SIZE =                                                                                   \
        (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2);                        \
    constexpr int kSmemByteSize =                                                                                      \
        (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES *      \
        sizeof(f16_t);                                                                                                 \
    if (kSmemByteSize >= 99 * 1024) {                                                                                  \
        printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);        \
        return _out_feats;                                                                                             \
    }                                                                                                                  \
    int j_factors1 = num_out_channels / CTA_N / 1;                                                                     \
    dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK);                                        \
    dim3 threads_per_block(WARP_SIZE, NUM_WARPS);                                                                      \
    auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>;           \
    /* Dynamic shared memory above 48 KB needs this opt-in; if it fails, the launch below cannot succeed. */           \
    cudaError_t _smem_attr_err =                                                                                       \
        cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);                 \
    if (_smem_attr_err != cudaSuccess) {                                                                               \
        printf("Failed to set %d Bytes of dynamic shared memory: %s\n",                                                \
               kSmemByteSize,                                                                                          \
               cudaGetErrorString(_smem_attr_err));                                                                    \
        return _out_feats;                                                                                             \
    }                                                                                                                  \
    kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(                                                     \
        in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);      \
    /* A kernel launch does not return an error itself; report a rejected launch configuration here. */                \
    cudaError_t _launch_err = cudaGetLastError();                                                                      \
    if (_launch_err != cudaSuccess) {                                                                                  \
        printf("gemm_w4a16_T1 launch failed: %s\n", cudaGetErrorString(_launch_err));                                  \
        return _out_feats;                                                                                             \
    }

template<int N>
__inline__ __host__ __device__ int get_log_tile(int n) {
    // Returns a log2 tiling factor in [0, 3] for get_block_idx_mapping:
    //   3 needs both N >= 8 and n >= 6;
    //   2 needs both N >= 4 and n >= 3;
    //   1 needs both N >= 2 and n >= 2;
    //   otherwise 0.
    // The compile-time N caps the factor; the runtime n is the tile count it must fit.
    return (N >= 8 && n >= 6)   ? 3
           : (N >= 4 && n >= 3) ? 2
           : (N >= 2 && n >= 2) ? 1
                                : 0;
}

Muyang Li's avatar
Muyang Li committed
65
66
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) {
    // Remaps a (x, y) block coordinate:
    //   - the low log_tile bits of x are moved into y;
    //   - y is scaled by 2^log_tile to make room for them.
    // With log_tile == 0 the mapping is the identity.
    const int low_bits_mask = (1 << log_tile) - 1;
    const int mapped_x      = blockIdx_x >> log_tile;
    const int mapped_y      = (blockIdx_y << log_tile) + (blockIdx_x & low_bits_mask);
    return make_uint2(mapped_x, mapped_y);
}

Muyang Li's avatar
Muyang Li committed
69
70
71
72
73
74
75
76
77
78
template<int SLICES, int NUM_WARPS_MN>
__device__ void sync_slice(int slice_id) {
    if constexpr (SLICES != 1) {
        // Several K-slices: synchronize on a named barrier instead of the whole block.
        // - Barrier ids start at 1, presumably to leave barrier 0 to __syncthreads().
        // - At most 8 ids (1..8) are used: slices are grouped ceil(SLICES / 8) per id.
        // - The expected arrival count is one slice's worth of threads (NUM_WARPS_MN warps).
        constexpr int slices_per_barrier          = (SLICES + 7) / 8;
        constexpr uint32_t participating_threads = NUM_WARPS_MN * WARP_SIZE;
        const uint32_t named_barrier             = 1 + slice_id / slices_per_barrier;
        asm volatile("bar.sync %0, %1;" ::"r"(named_barrier), "n"(participating_threads));
    } else {
        // A single slice spans the whole CTA, so a block-wide barrier is enough.
        __syncthreads();
    }
}

Muyang Li's avatar
Muyang Li committed
81
82
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
    // Convert a generic pointer into a shared-memory pointer to a 32-bit shared-state-space address.
    // ldmatrix and cp.async take their shared operand in this form.
    // __cvta_generic_to_shared (CUDA 11+) performs the same cvta.to.shared conversion as hand-written
    // inline PTX, and lets the compiler see and optimize the address computation.
    return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}

Muyang Li's avatar
Muyang Li committed
91
92
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
    // Warp-wide load of four 8x8 16-bit tiles from the shared address `addr`.
    // Each lane receives four 32-bit registers, written to shared_warp starting at element ax0_0 * 8.
    unsigned *frag = reinterpret_cast<unsigned *>(shared_warp + ax0_0 * 8);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
                 : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
                 : "r"(addr));
}

Muyang Li's avatar
Muyang Li committed
104
105
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
    // Warp-wide load of four 8x8 16-bit tiles from the shared address `addr`, transposing each tile (.trans).
    // Each lane receives four 32-bit registers, written to shared_warp starting at element ax0_0 * 8.
    unsigned *frag = reinterpret_cast<unsigned *>(shared_warp + ax0_0 * 8);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
                 : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
                 : "r"(addr));
}

Muyang Li's avatar
Muyang Li committed
117
118
119
120
121
122
123
124
125
126
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
    // Issue one predicated 16-byte asynchronous global -> shared copy (cp.async.cg).
    // When mask is false the predicate is off and no copy is issued, so the shared destination keeps
    // whatever it held before.
    // The copy is only issued here; it must be committed and waited on by the caller.
    constexpr int kCopyBytes = 16;
    asm volatile("{\n"
                 "  .reg .pred p;\n"
                 "  setp.ne.b32 p, %0, 0;\n"
                 "  @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                 "}\n"
                 :
                 : "r"(static_cast<int>(mask)), "r"(smem_int_ptr), "l"(src), "n"(kCopyBytes));
}

Muyang Li's avatar
Muyang Li committed
129
// Primary template: declared only. Just the explicit specializations for half and __nv_bfloat16 are defined,
// so any other element type fails at link time.
template<typename f16_t>
__inline__ __device__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16_t *B_shared_warp);

Muyang Li's avatar
Muyang Li committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
template<>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp) {
    // Warp-wide tensor-core MMA, D = A * B + C, with the accumulator updated in place. Per lane:
    //   A: row-major, 4 x 32-bit regs of packed fp16;
    //   B: column-major, 2 x 32-bit regs of packed fp16;
    //   C/D: 4 floats.
    // The accumulator operands are tied ("+f") so C and D share the same four registers.
    unsigned *a_regs = reinterpret_cast<unsigned *>(A_shared_warp);
    unsigned *b_regs = reinterpret_cast<unsigned *>(B_shared_warp);
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
                 : "+f"(C_warp[0]), "+f"(C_warp[1]), "+f"(C_warp[2]), "+f"(C_warp[3])
                 : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]), "r"(b_regs[0]), "r"(b_regs[1]));
}

Muyang Li's avatar
Muyang Li committed
150
151
template<>
__device__ __inline__ void
mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
    // Warp-wide tensor-core MMA, D = A * B + C, with the accumulator updated in place. Per lane:
    //   A: row-major, 4 x 32-bit regs of packed bf16;
    //   B: column-major, 2 x 32-bit regs of packed bf16;
    //   C/D: 4 floats.
    // The accumulator operands are tied ("+f") so C and D share the same four registers.
    unsigned *a_regs = reinterpret_cast<unsigned *>(A_shared_warp);
    unsigned *b_regs = reinterpret_cast<unsigned *>(B_shared_warp);
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
                 : "+f"(C_warp[0]), "+f"(C_warp[1]), "+f"(C_warp[2]), "+f"(C_warp[3])
                 : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]), "r"(b_regs[0]), "r"(b_regs[1]));
}

Muyang Li's avatar
Muyang Li committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(f16_t *src,
                                                       f16_t *dst,
                                                       int global_nrows,
                                                       int global_ncols,
                                                       int cta_offset_m,
                                                       int cta_offset_n,
                                                       int cta_offset_k,
                                                       int global_iter_k,
                                                       int shared_iter_k,
                                                       bool mask) {
    // Copies this call's share of one CTA_M x CTA_K tile of A from global memory into dst (one pipeline stage).
    // - Work unit: one uint4 (PACK_SIZE elements) per participating thread per iteration.
    // - The iterations are split over SHARED_K_ITERS calls; shared_iter_k selects this call's slice.
    // - With STAGES > 1 the copies are issued as cp.async; otherwise they are plain 16-byte stores.
    // cta_offset_n is not used.
    constexpr int vecs_in_tile    = (CTA_M * CTA_K) / PACK_SIZE;
    constexpr int wanted_threads  = vecs_in_tile / SHARED_K_ITERS;
    constexpr int active_threads  = wanted_threads < CTA_SIZE ? wanted_threads : CTA_SIZE;
    constexpr int iters_total     = vecs_in_tile / active_threads;
    constexpr int iters_this_call = (iters_total + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int rows_per_cta    = (active_threads * PACK_SIZE) / CTA_K;
    constexpr int rows_per_warp   = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int vecs_per_row    = CTA_K / PACK_SIZE;
    constexpr int smem_stride     = CTA_K + SMEM_PAD_A;

    // Only the first active_threads threads of the block (flattened over threadIdx.y) take part.
    const bool thread_active = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < active_threads);
    const int vec_col        = threadIdx.x % vecs_per_row;
    const int row_offset     = threadIdx.y * rows_per_warp + threadIdx.x / vecs_per_row;

#pragma unroll
    for (int it = 0; it < iters_this_call; ++it) {
        const int global_iter = shared_iter_k * iters_this_call + it;
        const int row         = global_iter * rows_per_cta + row_offset;
        // XOR-swizzle the 16-byte column with the low three row bits.
        // share_to_reg_one_stage_A applies the same swizzle when reading back, presumably to avoid
        // shared-memory bank conflicts.
        const int swizzled_col = (vec_col ^ (row & 7)) * PACK_SIZE;
        void *dst_ptr          = (void *)(dst + row * smem_stride + swizzled_col);
        uint4 *src_ptr         = (uint4 *)(src + (row + cta_offset_m) * global_ncols + vec_col * PACK_SIZE +
                                   global_iter_k * CTA_K + cta_offset_k);
        // Rows past the end of A (global_nrows) are skipped entirely.
        const bool do_copy = thread_active & (row + cta_offset_m < global_nrows);
        if constexpr (STAGES > 1) {
            cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr), src_ptr, do_copy);
        } else if (do_copy) {
            *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}

Muyang Li's avatar
Muyang Li committed
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(f16_t *src,
                                                       f16_t *dst,
                                                       int global_ncols,
                                                       int cta_offset_m,
                                                       int cta_offset_n,
                                                       int cta_offset_k,
                                                       int global_iter_k,
                                                       int shared_iter_k,
                                                       bool mask) {
    // Copies this call's share of one B tile into dst (one pipeline stage).
    // - The tile is (CTA_N / kInterleave) rows x CTA_K columns.
    // - Work unit: one uint4 (PACK_SIZE elements) per participating thread per iteration.
    // - With STAGES > 1 the copies are issued as cp.async; otherwise they are plain 16-byte stores.
    // NOTE(review): unlike the A loader there is no row bounds check, so N is presumably a multiple of CTA_N.
    constexpr int vecs_in_tile    = (CTA_N / kInterleave * CTA_K) / PACK_SIZE;
    constexpr int wanted_threads  = vecs_in_tile / SHARED_K_ITERS;
    constexpr int active_threads  = wanted_threads < CTA_SIZE ? wanted_threads : CTA_SIZE;
    constexpr int iters_total     = vecs_in_tile / active_threads;
    constexpr int iters_this_call = (iters_total + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int rows_per_cta    = (active_threads * PACK_SIZE) / CTA_K;
    constexpr int rows_per_warp   = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int vecs_per_row    = CTA_K / PACK_SIZE;
    constexpr int smem_stride     = CTA_K + SMEM_PAD_B;

    const bool thread_active = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < active_threads);
    const int vec_col        = threadIdx.x % vecs_per_row;
    const int row_offset     = threadIdx.y * rows_per_warp + threadIdx.x / vecs_per_row;

#pragma unroll
    for (int it = 0; it < iters_this_call; ++it) {
        const int global_iter = shared_iter_k * iters_this_call + it;
        const int row         = global_iter * rows_per_cta + row_offset;
        // The column swizzle flips only with the row parity (the `& 7` is kept for parity with the reader).
        const int swizzled_vec = vec_col ^ ((row % 2) & 7);
        void *dst_ptr          = (void *)(dst + (row * smem_stride + swizzled_vec * PACK_SIZE));
        uint4 *src_ptr         = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols +
                                   row * global_ncols + vec_col * PACK_SIZE + cta_offset_k);
        if constexpr (STAGES > 1) {
            cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr), src_ptr, thread_active);
        } else if (thread_active) {
            *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}

Muyang Li's avatar
Muyang Li committed
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales(f16_t *src,
                                                            f16_t *dst,
                                                            f16_t *src_z,
                                                            f16_t *dst_z,
                                                            int global_ncols,
                                                            int cta_offset_m,
                                                            int cta_offset_n,
                                                            int cta_offset_k,
                                                            int global_iter_k,
                                                            int shared_iter_k,
                                                            bool mask) {
    // Loads the per-group scales (src -> dst) and zeros (src_z -> dst_z) for one CTA_K slice.
    // - Shared memory is laid out row-major with CTA_N columns.
    // - global_ncols is the global row stride; g_idx is the group index of the slice's first K element.
    // - With STAGES > 1 the copies are issued as cp.async; otherwise they are plain 16-byte stores.
    // cta_offset_m and shared_iter_k are not used.
    //
    // Elements needed per matrix: one row of CTA_N when a group spans the whole slice (G >= CTA_K),
    // otherwise CTA_N * CTA_K / G.
    constexpr int LD_AMOUNT       = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
    constexpr int threads_needed  = LD_AMOUNT / PACK_SIZE;
    constexpr int threads_used    = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int threads_per_row = CTA_N / PACK_SIZE;
    constexpr int kSmemCol        = CTA_N;
    // NOTE(review): each thread issues exactly one 16-byte load per matrix.
    // Full coverage therefore assumes threads_needed <= CTA_SIZE -- confirm for all instantiated configs.

    // Thread index flattened over the (WARP_SIZE x NUM_WARPS) block. Addressing rows by this index
    // (rather than threadIdx.x alone) keeps loads distinct when threads_used > WARP_SIZE; otherwise
    // threads of later warps would duplicate warp 0's rows and the remaining rows would never be loaded.
    const int tid         = threadIdx.y * WARP_SIZE + threadIdx.x;
    const bool local_mask = mask & (tid < threads_used);
    const int g_idx       = (cta_offset_k + global_iter_k * CTA_K) / G;
    const int row         = tid / threads_per_row;
    const int col         = (tid % threads_per_row) * PACK_SIZE;

    f16_t *dst_ptr   = dst + row * kSmemCol + col;
    f16_t *dst_ptr_z = dst_z + row * kSmemCol + col;
    const uint4 *src_ptr =
        reinterpret_cast<const uint4 *>(src + g_idx * global_ncols + cta_offset_n + row * global_ncols + col);
    const uint4 *src_ptr_z =
        reinterpret_cast<const uint4 *>(src_z + g_idx * global_ncols + cta_offset_n + row * global_ncols + col);

    // Compile-time dispatch, consistent with the A/B stage loaders.
    if constexpr (STAGES > 1) {
        cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr), src_ptr, local_mask);
        cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr_z), src_ptr_z, local_mask);
    } else if (local_mask) {
        *reinterpret_cast<uint4 *>(dst_ptr)   = *src_ptr;
        *reinterpret_cast<uint4 *>(dst_ptr_z) = *src_ptr_z;
    }
}

Muyang Li's avatar
Muyang Li committed
295
296
297
298
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void
share_to_reg_one_stage_A(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) {
    // Loads shared_iters 16x16 A fragments from the swizzled shared tile `src` into `dst` via ldmatrix.x4.
    // Fragment i goes to dst element offset i * 8.
    // Lane addressing:
    //   - lanes 0-15 address rows 0-15 at column k_0_1 * 16 (+ warp_offset_k);
    //   - lanes 16-31 address the same rows at column + 8.
    // The XOR swizzle matches the one applied by global_to_share_one_stage_A. warp_offset_n is not used.
    constexpr int smem_stride = CTA_K + SMEM_PAD_A;
    const int lane_row        = threadIdx.x % 16;
    const int col             = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
    const int col_vec         = col / PACK_SIZE;

    for (int frag = 0; frag < shared_iters; ++frag) {
        const int row          = warp_offset_m + frag * OP_M + lane_row;
        const int swizzled_col = (col_vec ^ (row & 7)) * PACK_SIZE;
        ldmatrix_m8n8_x4_b16(dst, frag, cast_smem_ptr_to_uint(src + row * smem_stride + swizzled_col));
    }
}

Muyang Li's avatar
Muyang Li committed
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B(f16_t *src,
                                                    f16_t *src_scales,
                                                    f16_t *src_zeros,
                                                    f16_t *dst,
                                                    f16_t *dst_fp16,
                                                    int warp_offset_m,
                                                    int warp_offset_n,
                                                    int warp_offset_k,
                                                    int k_0_1) {
    // Two phases:
    //   1. (only if `ldmatrix`) load packed B fragments from shared memory into dst;
    //   2. dequantize the fragment selected by k_0_1 into dst_fp16 as value * scale + zero.
    // warp_offset_m is not used.
    using f162_t = typename packed_as<f16_t, 2>::type;

    constexpr int smem_stride = CTA_K + SMEM_PAD_B;

    if constexpr (ldmatrix) {
        // Per-lane (row, column) inside the B tile; the tile has CTA_N / kInterleave rows.
        // The column swizzle depends only on row parity, matching the B stage loader.
        const int lane_row     = (threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8;
        const int lane_col     = ((threadIdx.x / 8) % 2) * 8;
        const int smem_row     = lane_row / 4;
        const int smem_col     = (lane_row % 4) * 16 + lane_col;
        const int swizzled_col = ((smem_col / PACK_SIZE) ^ (smem_row % 2) & 7) * PACK_SIZE;
#pragma unroll
        for (int frag = 0; frag < shared_iters; ++frag) {
            void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * smem_stride +
                                      frag * 16 / kInterleave * smem_stride + k_0_1 * 16 + smem_row * smem_stride +
                                      swizzled_col + warp_offset_k);
            ldmatrix_m8n8_x4_b16(dst, frag, cast_smem_ptr_to_uint(addr_ptr));
        }
    }

    // Scale/zero row for this warp's K offset. Each group of 4 lanes shares one output column.
    const int group_row_base = (warp_offset_k / G) * CTA_N + warp_offset_n;
#pragma unroll
    for (int frag = 0; frag < shared_iters; ++frag) {
        const int sz_idx = group_row_base + 16 * frag + 8 * (k_0_1 % 2) + threadIdx.x / 4;
        f16_t scale      = src_scales[sz_idx];
        f16_t zero       = src_zeros[sz_idx];
        f162_t scale2    = f162f162(scale);
        f162_t zero2     = f162f162(zero);
        f162_t vals[4];

        // Presumably unpacks the packed 4-bit weights of this fragment into four f162_t pairs -- confirm
        // against dequantize.cuh.
        dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + frag * 8),
                                reinterpret_cast<uint4 *>(vals));
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            // Fused multiply-add: value * scale + zero (the zero term is an additive offset after scaling).
            vals[i] = __hfma2(vals[i], scale2, zero2);
        }
        *reinterpret_cast<uint4 *>(dst_fp16 + frag * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(vals);
    }
}

Muyang Li's avatar
Muyang Li committed
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
template<typename f16_t,
         int CTA_M,
         int CTA_N,
         int CTA_K,
         int WARP_M,
         int WARP_N,
         int WARP_K,
         int STAGES,
         int G,
         int SPLITK>
__global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
                              f16_t *__restrict__ B,
                              f16_t *__restrict__ scales,
                              f16_t *__restrict__ zeros,
                              f16_t *__restrict__ C,
                              int *__restrict__ semaphores,
                              int M,
                              int N,
                              int K) {
fengzch-das's avatar
fengzch-das committed
381
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
Muyang Li's avatar
Muyang Li committed
382
383
    trap_unsupported_arch();
    return;
384
#endif
Muyang Li's avatar
Muyang Li committed
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
    using f162_t = typename packed_as<f16_t, 2>::type;

    constexpr int NUM_WARPS_MN    = CTA_M / WARP_M * CTA_N / WARP_N;
    constexpr int NUM_WARPS       = NUM_WARPS_MN * CTA_K / WARP_K;
    constexpr int CTA_SIZE        = NUM_WARPS * WARP_SIZE;
    constexpr int CTA_SIZE_MN     = NUM_WARPS_MN * WARP_SIZE;
    constexpr int SLICES          = CTA_K / WARP_K;
    int num_blocks_n              = (N + CTA_N - 1) / CTA_N;
    int num_blocks_m              = (M + CTA_M - 1) / CTA_M;
    int blockIdx_x                = 0;
    int blockIdx_y                = blockIdx.x % (num_blocks_m * num_blocks_n);
    int blockIdx_z                = blockIdx.x / (num_blocks_m * num_blocks_n);
    const int log_tile            = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
    int blockIdx_m                = blockIdx_y / (num_blocks_n >> log_tile);
    int blockIdx_n                = blockIdx_y % (num_blocks_n >> log_tile);
    const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
    blockIdx_m                    = block_idx_mapping.x;
    blockIdx_n                    = block_idx_mapping.y;

    float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
    constexpr int kSmemPadKA           = CTA_K + SMEM_PAD_A;
    constexpr int kSmemPadKB           = CTA_K + SMEM_PAD_B;
    constexpr int kSmemSizeAPerStage   = CTA_M * kSmemPadKA;
    constexpr int kSmemSizeBPerStage   = CTA_N / kInterleave * kSmemPadKB;
    constexpr int kSmemSizeA           = kSmemSizeAPerStage * STAGES;
    constexpr int kSmemSizeB           = kSmemSizeBPerStage * STAGES;
    constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
    constexpr int scales_per_load      = G < CTA_K ? CTA_K / G : 1;
    constexpr int kSmemSizeScales      = CTA_N * STAGES / scales_load_interval * scales_per_load;
    constexpr int kSmemSizeZeros       = CTA_N * STAGES / scales_load_interval * scales_per_load;
    extern __shared__ half mem_shared[];
    f16_t *A_shared      = reinterpret_cast<f16_t *>(mem_shared);
    f16_t *B_shared      = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
    f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
    f16_t *zeros_shared  = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
    float *C_shared      = reinterpret_cast<float *>(mem_shared);
    f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
    f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
    f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
    int cta_offset_m  = blockIdx_m * CTA_M;
    int cta_offset_n  = blockIdx_n * CTA_N;
    int cta_offset_k  = blockIdx_z * (K / SPLITK);
    int warp_mn       = threadIdx.y % NUM_WARPS_MN;
    int slice_id      = threadIdx.y / NUM_WARPS_MN;
    int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
    int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
    int warp_offset_k = slice_id * WARP_K;

    for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
        C_warp[i] = 0.0;

    int gemm_iters                = (K + CTA_K - 1) / CTA_K / SPLITK;
    int k_0_0_ld                  = 0;
    int k_0_0                     = 0;
    constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
440
#pragma unroll
Muyang Li's avatar
Muyang Li committed
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
    for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
        global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A,
                                                                                     A_shared +
                                                                                         k_0_0_ld * kSmemSizeAPerStage,
                                                                                     M,
                                                                                     K,
                                                                                     cta_offset_m,
                                                                                     cta_offset_n,
                                                                                     cta_offset_k,
                                                                                     k_0_0_ld,
                                                                                     0,
                                                                                     true);
        global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B,
                                                                                     B_shared +
                                                                                         k_0_0_ld * kSmemSizeBPerStage,
                                                                                     K,
                                                                                     cta_offset_m,
                                                                                     cta_offset_n,
                                                                                     cta_offset_k,
                                                                                     k_0_0_ld,
                                                                                     0,
                                                                                     true);
        global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
            scales,
            scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
            zeros,
            zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
            N,
            cta_offset_m,
            cta_offset_n,
            cta_offset_k,
            k_0_0_ld,
            0,
            k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
        if constexpr (STAGES > 1)
            __pipeline_commit();
    }
478
    if constexpr (STAGES > 1)
Muyang Li's avatar
Muyang Li committed
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
        __pipeline_wait_prior(STAGES - 2);
    __syncthreads();

    share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
        A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared,
                                                                                             scales_shared,
                                                                                             zeros_shared,
                                                                                             B_shared_warp_tmp_[0],
                                                                                             B_shared_warp_[0],
                                                                                             warp_offset_m,
                                                                                             warp_offset_n,
                                                                                             warp_offset_k,
                                                                                             0);
    constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;

    for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
        int ld_stage      = k_0_0_ld % STAGES;
        int compute_stage = k_0_0 % STAGES;
        f16_t *A_shared_this_compute_stage;
        f16_t *B_shared_this_compute_stage;
        f16_t *scales_shared_this_compute_stage;
        f16_t *zeros_shared_this_compute_stage;
502
503

#pragma unroll
Muyang Li's avatar
Muyang Li committed
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
        for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
            A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
            B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
            scales_shared_this_compute_stage =
                scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
            zeros_shared_this_compute_stage =
                zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
            share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
                A_shared_this_compute_stage,
                A_shared_warp_[(iter_k + 1) % 2],
                warp_offset_m,
                warp_offset_n,
                warp_offset_k,
                (iter_k + 1) % SHARED_K_ITERS);
            if ((iter_k + 1) % kInterleave == 0) {
                if (compute_stage % 2 == 1) {
                    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage,
                        scales_shared_this_compute_stage,
                        zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[1],
                        B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m,
                        warp_offset_n,
                        warp_offset_k,
                        (iter_k + 1) % SHARED_K_ITERS);
                } else {
                    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage,
                        scales_shared_this_compute_stage,
                        zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[0],
                        B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m,
                        warp_offset_n,
                        warp_offset_k,
                        (iter_k + 1) % SHARED_K_ITERS);
                }
            } else {
                if (compute_stage % 2 == 1) {
                    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage,
                        scales_shared_this_compute_stage,
                        zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[1],
                        B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m,
                        warp_offset_n,
                        warp_offset_k,
                        (iter_k + 1) % SHARED_K_ITERS);
                } else {
                    share_to_reg_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage,
                        scales_shared_this_compute_stage,
                        zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[0],
                        B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m,
                        warp_offset_n,
                        warp_offset_k,
                        (iter_k + 1) % SHARED_K_ITERS);
                }
            }
            f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
            f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];

            for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
                for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
                    mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8,
                                 A_shared_warp + i_0_3 * 8,
                                 B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
                    mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4,
                                 A_shared_warp + i_0_3 * 8,
                                 B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
                }
            }

            if (iter_k < WARP_K / INTRIN_K - 1) {
                if constexpr (STAGES == 1)
                    __syncthreads();
                global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    A,
                    A_shared + ld_stage * kSmemSizeAPerStage,
                    M,
                    K,
                    cta_offset_m,
                    cta_offset_n,
                    cta_offset_k,
                    k_0_0_ld,
                    iter_k,
                    k_0_0_ld < gemm_iters);
                global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    B,
                    B_shared + ld_stage * kSmemSizeBPerStage,
                    K,
                    cta_offset_m,
                    cta_offset_n,
                    cta_offset_k,
                    k_0_0_ld,
                    iter_k,
                    k_0_0_ld < gemm_iters);
            }

            if (iter_k == WARP_K / INTRIN_K - 2) {
                if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) {
                    __syncthreads();
                }
                global_to_share_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    A,
                    A_shared + ld_stage * kSmemSizeAPerStage,
                    M,
                    K,
                    cta_offset_m,
                    cta_offset_n,
                    cta_offset_k,
                    k_0_0_ld,
                    iter_k + 1,
                    k_0_0_ld < gemm_iters);
                global_to_share_one_stage_B<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    B,
                    B_shared + ld_stage * kSmemSizeBPerStage,
                    K,
                    cta_offset_m,
                    cta_offset_n,
                    cta_offset_k,
                    k_0_0_ld,
                    iter_k + 1,
                    k_0_0_ld < gemm_iters);
                global_to_share_one_stage_scales<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
                    scales,
                    scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
                    zeros,
                    zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
                    N,
                    cta_offset_m,
                    cta_offset_n,
                    cta_offset_k,
                    k_0_0_ld,
                    iter_k,
                    k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
                if constexpr (STAGES > 1) {
                    __pipeline_commit();
                    __pipeline_wait_prior(STAGES - 2);
                }
                compute_stage = (k_0_0 + 1) % STAGES;
                __syncthreads();
            }
651
652
        }
    }
Muyang Li's avatar
Muyang Li committed
653
654
655
656
    __pipeline_commit();
    __pipeline_wait_prior(0);
    __syncthreads();
    if constexpr (SLICES > 1) {
657
#pragma unroll
Muyang Li's avatar
Muyang Li committed
658
659
        for (int z = 0; z < SLICES; ++z) {
            if (slice_id == z) {
660
#pragma unroll
Muyang Li's avatar
Muyang Li committed
661
                for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
662
#pragma unroll
Muyang Li's avatar
Muyang Li committed
663
                    for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
664
#pragma unroll
Muyang Li's avatar
Muyang Li committed
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
                        for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
                            if (z > 0) {
                                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] +=
                                    C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n +
                                             ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N +
                                             (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
                            }
                            C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                                     ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 +
                                     (local_id % 2) + (threadIdx.x % 4) * 2] =
                                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
                        };
                    }
                }
            }
            __syncthreads();
681
        }
Muyang Li's avatar
Muyang Li committed
682
        if (slice_id == 0) {
683
#pragma unroll
Muyang Li's avatar
Muyang Li committed
684
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
685
#pragma unroll
Muyang Li's avatar
Muyang Li committed
686
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
687
#pragma unroll
Muyang Li's avatar
Muyang Li committed
688
689
690
691
692
693
694
695
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) {
                        C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] =
                            C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 +
                                     ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 +
                                     (local_id % 2) + (threadIdx.x % 4) * 2];
                    };
                }
            }
696
697
698
        }
    }

Muyang Li's avatar
Muyang Li committed
699
700
    if (slice_id == 0) {
        Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
701

Muyang Li's avatar
Muyang Li committed
702
703
704
        if constexpr (SPLITK > 1) {
            semaphore.fetch();
        }
705

Muyang Li's avatar
Muyang Li committed
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
        if (blockIdx_z != 0) {
            semaphore.wait(blockIdx_z);
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
                        int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M +
                                        ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));

                        if (write_row < M) {
                            f162_t *existing_psum_ptr = reinterpret_cast<f162_t *>(
                                C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + (local_id / 4) * 8 +
                                (local_id % 2) + (threadIdx.x % 4) * 2);

                            *existing_psum_ptr =
                                __hadd2(*existing_psum_ptr,
                                        cuda_cast<f162_t>(*reinterpret_cast<float2 *>(
                                            C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)));
                        }
                    };
                }
726
            }
Muyang Li's avatar
Muyang Li committed
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
        } else {
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
                        int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M +
                                        ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
                        if (write_row < M) {
                            *reinterpret_cast<f162_t *>(C + write_row * N + cta_offset_n + warp_offset_n +
                                                        ax1_0_1 * 16 + (local_id / 4) * 8 + (local_id % 2) +
                                                        (threadIdx.x % 4) * 2) =
                                cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
                                                                              ax1_0_1 * 8 + local_id));
                        }
                    };
                }
742
743
744
            }
        }

Muyang Li's avatar
Muyang Li committed
745
        if constexpr (SPLITK > 1) {
746

Muyang Li's avatar
Muyang Li committed
747
748
            int lock = 0;
            if (SPLITK == blockIdx_z + 1) {
749

Muyang Li's avatar
Muyang Li committed
750
751
752
753
754
755
                lock = 0;
            } else {
                lock = blockIdx_z + 1;
            }
            semaphore.release(lock);
        }
756
757
758
    }
}

Muyang Li's avatar
Muyang Li committed
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
// Copies one portion of the CTA_M x CTA_K activation tile A from global memory into a
// shared-memory stage buffer (T2 kernel variant).
//
// Each participating thread moves one uint4 (PACK_SIZE f16_t elements) per iteration. The
// whole tile is split across SHARED_K_ITERS calls; `shared_iter_k` selects the portion handled
// by this call. Threads whose linear index is not below the number of threads required for a
// portion do nothing.
//
// Inside each shared-memory row, the PACK_SIZE-element chunk index is XOR-ed with (row & 7)
// (presumably to spread accesses across shared-memory banks); share_to_reg_one_stage_A_T2
// applies the same swizzle on the read side.
//
// With STAGES > 1 the copy is issued through cp_async_cg_A; otherwise a synchronous uint4
// load/store is used. The copy predicate is false for rows at or beyond `global_nrows`.
//
//   src           - global A, row stride `global_ncols` elements
//   dst           - shared-memory stage buffer for A (row stride CTA_K + SMEM_PAD_A)
//   global_nrows  - number of valid rows of A (M at the visible call site)
//   global_ncols  - row stride of A in elements (K at the visible call site)
//   cta_offset_m  - first global row covered by this CTA
//   cta_offset_n  - unused; kept so the signature mirrors the B loader
//   global_iter_k - index of the CTA_K-wide K tile to load
//   shared_iter_k - which of the SHARED_K_ITERS portions to load
//   mask          - when false, nothing is copied
//
// NOTE(review): `global_row * global_ncols` is 32-bit int arithmetic; assumes M * K fits in int.
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A_T2(f16_t *src,
                                                          f16_t *dst,
                                                          int global_nrows,
                                                          int global_ncols,
                                                          int cta_offset_m,
                                                          int cta_offset_n,
                                                          int global_iter_k,
                                                          int shared_iter_k,
                                                          bool mask) {
    constexpr int kChunksInTile   = (CTA_M * CTA_K) / PACK_SIZE;
    constexpr int kThreadsWanted  = kChunksInTile / SHARED_K_ITERS;
    constexpr int kThreadsActive  = kThreadsWanted < CTA_SIZE ? kThreadsWanted : CTA_SIZE;
    constexpr int kItersTotal     = kChunksInTile / kThreadsActive;
    constexpr int kItersPerCall   = (kItersTotal + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int kRowsPerCtaStep = (kThreadsActive * PACK_SIZE) / CTA_K;
    constexpr int kRowsPerWarp    = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int kChunksPerRow   = CTA_K / PACK_SIZE;
    constexpr int kSmemRowStride  = CTA_K + SMEM_PAD_A;

    // Position of this thread's chunk inside a tile row; fixed for the whole call.
    const int chunk      = threadIdx.x % kChunksPerRow;
    const int lane_row   = threadIdx.x / kChunksPerRow;
    const bool is_active = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < kThreadsActive);

#pragma unroll
    for (int step = 0; step < kItersPerCall; ++step) {
        const int tile_row   = (shared_iter_k * kItersPerCall + step) * kRowsPerCtaStep +
                             threadIdx.y * kRowsPerWarp + lane_row;
        const int global_row = tile_row + cta_offset_m;

        void *smem_ptr     = (void *)(dst + tile_row * kSmemRowStride + (chunk ^ (tile_row & 7)) * PACK_SIZE);
        uint4 *gmem_ptr    = (uint4 *)(src + global_row * global_ncols + chunk * PACK_SIZE + global_iter_k * CTA_K);
        const bool do_copy = is_active & (global_row < global_nrows);

        if constexpr (STAGES > 1) {
            const uint32_t smem_addr = cast_smem_ptr_to_uint(smem_ptr);
            cp_async_cg_A(smem_addr, gmem_ptr, do_copy);
        } else if (do_copy) {
            *(uint4 *)smem_ptr = *gmem_ptr;
        }
    }
}

Muyang Li's avatar
Muyang Li committed
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
// Copies one portion of the B tile for the current K tile from global memory into a
// shared-memory stage buffer (T2 kernel variant).
//
// The shared tile has CTA_N / kInterleave rows of CTA_K elements. Tile row r is read from
// global row (cta_offset_n / kInterleave + r), row stride `global_ncols`, starting at column
// global_iter_k * CTA_K.
// NOTE(review): presumably B holds packed 4-bit weights in an interleaved layout viewed as
// f16_t — confirm against the weight-packing code.
//
// Each participating thread moves one uint4 (PACK_SIZE f16_t elements) per iteration; the tile
// is split across SHARED_K_ITERS calls and `shared_iter_k` selects this call's portion. The
// chunk index inside a shared-memory row is XOR-ed with the row parity, matching the read
// side in share_to_reg_one_stage_B_T2.
//
// Unlike the A loader, there is no row bound check: only `mask` and the active-thread test
// gate the copy. NOTE(review): assumes N is a multiple of CTA_N — confirm at the launch site.
//
//   src           - global B, row stride `global_ncols` elements
//   dst           - shared-memory stage buffer for B (row stride CTA_K + SMEM_PAD_B)
//   global_ncols  - row stride of B in elements (K at the visible call site)
//   cta_offset_m  - unused; kept so the signature mirrors the A loader
//   cta_offset_n  - first output channel covered by this CTA
//   global_iter_k - index of the CTA_K-wide K tile to load
//   shared_iter_k - which of the SHARED_K_ITERS portions to load
//   mask          - when false, nothing is copied
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B_T2(f16_t *src,
                                                          f16_t *dst,
                                                          int global_ncols,
                                                          int cta_offset_m,
                                                          int cta_offset_n,
                                                          int global_iter_k,
                                                          int shared_iter_k,
                                                          bool mask) {
    constexpr int kChunksInTile   = (CTA_N / kInterleave * CTA_K) / PACK_SIZE;
    constexpr int kThreadsWanted  = kChunksInTile / SHARED_K_ITERS;
    constexpr int kThreadsActive  = kThreadsWanted < CTA_SIZE ? kThreadsWanted : CTA_SIZE;
    constexpr int kItersTotal     = kChunksInTile / kThreadsActive;
    constexpr int kItersPerCall   = (kItersTotal + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int kRowsPerCtaStep = (kThreadsActive * PACK_SIZE) / CTA_K;
    constexpr int kRowsPerWarp    = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int kChunksPerRow   = CTA_K / PACK_SIZE;
    constexpr int kSmemRowStride  = CTA_K + SMEM_PAD_B;

    const bool is_active = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < kThreadsActive);
    const int chunk      = threadIdx.x % kChunksPerRow;
    const int lane_row   = threadIdx.x / kChunksPerRow;
    // Global address of this thread's chunk in tile row 0 of the CTA's slice for this K tile.
    f16_t *slice = src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + chunk * PACK_SIZE;

#pragma unroll
    for (int step = 0; step < kItersPerCall; ++step) {
        const int tile_row = (shared_iter_k * kItersPerCall + step) * kRowsPerCtaStep +
                             threadIdx.y * kRowsPerWarp + lane_row;
        // Row-parity swizzle of the chunk index (the `& 7` is kept to mirror the read side).
        const int swizzled_chunk = chunk ^ ((tile_row % 2) & 7);

        void *smem_ptr  = (void *)(dst + (tile_row * kSmemRowStride + swizzled_chunk * PACK_SIZE));
        uint4 *gmem_ptr = (uint4 *)(slice + tile_row * global_ncols);

        if constexpr (STAGES > 1) {
            const uint32_t smem_addr = cast_smem_ptr_to_uint(smem_ptr);
            cp_async_cg_A(smem_addr, gmem_ptr, is_active);
        } else if (is_active) {
            *(uint4 *)smem_ptr = *gmem_ptr;
        }
    }
}

Muyang Li's avatar
Muyang Li committed
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
// Loads the quantization-group scale row and zero row covering this CTA's CTA_N output
// channels for K tile `global_iter_k` into shared memory (T2 kernel variant).
//
// The group row is g_idx = global_iter_k * CTA_K / G. The element offset is
// g_idx * global_ncols + cta_offset_n, so `src` / `src_z` are read row-major with row stride
// `global_ncols` (N at the visible call site).
// NOTE(review): presumably a (K / G) x N layout — confirm.
//
// Only one group row is loaded per call. The visible T2 caller enables `mask` only every
// G / CTA_K K tiles, so each CTA_K tile is assumed to lie inside a single group.
//
// Each active thread copies one uint4 (PACK_SIZE elements) of scales and one of zeros, so
// `dst` / `dst_z` each receive CTA_N elements. The column chunk is selected by threadIdx.x
// alone, so one warp must span the whole row. The static_assert enforces this; otherwise
// columns beyond WARP_SIZE * PACK_SIZE would never be loaded.
//
//   src, src_z    - global scales / zeros
//   dst, dst_z    - shared-memory destinations for this stage
//   global_ncols  - row stride of the global scales / zeros
//   cta_offset_m  - unused
//   cta_offset_n  - first output channel covered by this CTA
//   global_iter_k - index of the CTA_K-wide K tile whose group is loaded
//   shared_iter_k - unused
//   mask          - when false, nothing is copied
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales_T2(f16_t *src,
                                                               f16_t *dst,
                                                               f16_t *src_z,
                                                               f16_t *dst_z,
                                                               int global_ncols,
                                                               int cta_offset_m,
                                                               int cta_offset_n,
                                                               int global_iter_k,
                                                               int shared_iter_k,
                                                               bool mask) {
    constexpr int threads_per_row = CTA_N / PACK_SIZE;
    constexpr int threads_used    = threads_per_row < CTA_SIZE ? threads_per_row : CTA_SIZE;
    static_assert(threads_per_row <= WARP_SIZE,
                  "global_to_share_one_stage_scales_T2 selects columns by threadIdx.x only; "
                  "CTA_N must not exceed WARP_SIZE * PACK_SIZE");

    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    int g_idx       = global_iter_k * CTA_K / G;
    int col         = (threadIdx.x % threads_per_row) * PACK_SIZE;

    void *dst_ptr    = (void *)(dst + col);
    uint4 *src_ptr   = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + col);
    void *dst_ptr_z  = (void *)(dst_z + col);
    uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + col);

    // Compile-time dispatch, consistent with the A/B loaders.
    if constexpr (STAGES > 1) {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask);
        uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
        cp_async_cg_A(addr_z, src_ptr_z, local_mask);
    } else {
        if (local_mask) {
            *(uint4 *)dst_ptr   = *src_ptr;
            *(uint4 *)dst_ptr_z = *src_ptr_z;
        }
    }
}

Muyang Li's avatar
Muyang Li committed
876
877
878
879
// Loads this warp's A fragments for K sub-step `k_0_1` from a shared-memory stage buffer into
// registers via ldmatrix_m8n8_x4_b16 (T2 kernel variant).
//
// For each of the `shared_iters` OP_M-row sub-tiles, lane l supplies the address of tile row
// warp_offset_m + tile * OP_M + (l % 16), at column k_0_1 * 16 + (l / 16) * 8. The column's
// chunk index is XOR-ed with (row & 7), the same swizzle global_to_share_one_stage_A_T2 uses
// when writing.
//
//   src           - shared-memory stage buffer for A (row stride CTA_K + SMEM_PAD_A)
//   dst           - per-thread register fragment; the sub-tile index is passed through to
//                   ldmatrix_m8n8_x4_b16 (presumably as the fragment slot — confirm)
//   warp_offset_m - first tile row owned by this warp
//   warp_offset_n - unused
//   k_0_1         - 16-wide K sub-step inside the CTA_K tile
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void
share_to_reg_one_stage_A_T2(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1) {
    constexpr int kSmemRowStride = CTA_K + SMEM_PAD_A;

    // These depend only on the lane and k_0_1, not on the sub-tile.
    const int lane_row = threadIdx.x % 16;
    const int k_col    = k_0_1 * 16 + (threadIdx.x / 16) * 8;
    const int chunk    = k_col / PACK_SIZE;

    for (int tile = 0; tile < shared_iters; ++tile) {
        const int row            = warp_offset_m + tile * OP_M + lane_row;
        const int swizzled_col   = (chunk ^ (row & 7)) * PACK_SIZE;
        const uint32_t smem_addr = cast_smem_ptr_to_uint((void *)(src + row * kSmemRowStride + swizzled_col));
        ldmatrix_m8n8_x4_b16(dst, tile, smem_addr);
    }
}

Muyang Li's avatar
Muyang Li committed
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
// Builds this warp's dequantized B fragments for K sub-step `k_0_1` (T2 kernel variant).
//
// Step 1 (only when `ldmatrix` is true): for each of the `shared_iters` sub-tiles, fetch raw B
//   data from the shared-memory stage `src` into `dst` with ldmatrix_m8n8_x4_b16. Sub-tile t
//   starts at shared row (warp_offset_n / kInterleave + t * 16 / kInterleave). The lane's chunk
//   column is XOR-ed with the row parity, matching global_to_share_one_stage_B_T2.
// Step 2 (always): for each sub-tile, read 32 bits from `dst` at element offset
//   (k_0_1 % 2) * 4 + (k_0_1 / 2) * 2 + t * 8 and expand them with dequantize_s4_to_fp16x2
//   into four f162_t values. Each value becomes value * scale + zero, with one scale/zero per
//   thread chosen by threadIdx.x / 4. The 8 results are stored to
//   dst_fp16 + t * 16 + 8 * (k_0_1 % 2).
//
// NOTE(review): when `ldmatrix` is false, `dst` presumably still holds raw data fetched by an
// earlier call — confirm against the caller's schedule.
// NOTE(review): the `* scale + zero` form suggests `src_zeros` holds pre-scaled zero points —
// confirm against the quantizer.
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src,
                                                       f16_t *src_scales,
                                                       f16_t *src_zeros,
                                                       f16_t *dst,
                                                       f16_t *dst_fp16,
                                                       int warp_offset_m,
                                                       int warp_offset_n,
                                                       int k_0_1) {
    using f162_t                 = typename packed_as<f16_t, 2>::type;
    constexpr int kSmemRowStride = CTA_K + SMEM_PAD_B;

    if constexpr (ldmatrix) {
        // Per-lane ldmatrix row address inside one sub-tile of the interleaved B layout.
        const int lane_r0       = (threadIdx.x / 16) * 8 + threadIdx.x % 8;
        const int lane_c0       = ((threadIdx.x / 8) % 2) * 8;
        const int row           = lane_r0 / 4;
        const int col           = (lane_r0 % 4) * 16 + lane_c0;
        const int swizzled_col  = ((col / PACK_SIZE) ^ ((row % 2) & 7)) * PACK_SIZE;
#pragma unroll
        for (int tile = 0; tile < shared_iters; ++tile) {
            const int smem_row       = warp_offset_n / kInterleave + tile * 16 / kInterleave + row;
            void *addr_ptr           = (void *)(src + smem_row * kSmemRowStride + k_0_1 * 16 + swizzled_col);
            const uint32_t smem_addr = cast_smem_ptr_to_uint(addr_ptr);
            ldmatrix_m8n8_x4_b16(dst, tile, smem_addr);
        }
    }

    const int k_half = k_0_1 % 2;
    // Scale/zero index for sub-tile 0; sub-tile t adds 16 * t.
    const int sz_base = warp_offset_n + 8 * k_half + threadIdx.x / 4;
    // Element offset inside `dst` of the packed 32-bit word dequantized for sub-tile 0.
    const int packed_base = k_half * 4 + (k_0_1 / 2) * 2;

#pragma unroll
    for (int tile = 0; tile < shared_iters; ++tile) {
        f16_t scale   = src_scales[sz_base + 16 * tile];
        f16_t zero    = src_zeros[sz_base + 16 * tile];
        f162_t scale2 = f162f162(scale);
        f162_t zero2  = f162f162(zero);

        f162_t vals[4];
        dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + packed_base + tile * 8),
                                reinterpret_cast<uint4 *>(vals));
#pragma unroll
        for (int v = 0; v < 4; v++) {
            vals[v] = __hfma2(vals[v], scale2, zero2);
        }
        *reinterpret_cast<uint4 *>(dst_fp16 + tile * 16 + 8 * k_half) = *reinterpret_cast<uint4 *>(vals);
    }
}

Muyang Li's avatar
Muyang Li committed
938
939
940
941
942
943
944
945
946
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
                              f16_t *__restrict__ B,
                              f16_t *__restrict__ scales,
                              f16_t *__restrict__ zeros,
                              f16_t *__restrict__ C,
                              int M,
                              int N,
                              int K) {
fengzch-das's avatar
fengzch-das committed
947
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
Muyang Li's avatar
Muyang Li committed
948
949
    trap_unsupported_arch();
    return;
950
#endif
Muyang Li's avatar
Muyang Li committed
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
    using f162_t                  = typename packed_as<f16_t, 2>::type;
    constexpr int NUM_WARPS       = CTA_M / WARP_M * CTA_N / WARP_N;
    constexpr int CTA_SIZE        = NUM_WARPS * WARP_SIZE;
    int num_blocks_n              = (N + CTA_N - 1) / CTA_N;
    int num_blocks_m              = (M + CTA_M - 1) / CTA_M;
    int blockIdx_x                = 0;
    int blockIdx_y                = blockIdx.x % (num_blocks_m * num_blocks_n);
    int blockIdx_z                = blockIdx.x / (num_blocks_m * num_blocks_n);
    const int log_tile            = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
    int blockIdx_m                = blockIdx_y / (num_blocks_n >> log_tile);
    int blockIdx_n                = blockIdx_y % (num_blocks_n >> log_tile);
    const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
    blockIdx_m                    = block_idx_mapping.x;
    blockIdx_n                    = block_idx_mapping.y;

    float C_warp[CTA_M * CTA_N / CTA_SIZE];
    constexpr int kSmemPadKA           = CTA_K + SMEM_PAD_A;
    constexpr int kSmemPadKB           = CTA_K + SMEM_PAD_B;
    constexpr int kSmemSizeAPerStage   = CTA_M * kSmemPadKA;
    constexpr int kSmemSizeBPerStage   = CTA_N / kInterleave * kSmemPadKB;
    constexpr int kSmemSizeA           = kSmemSizeAPerStage * STAGES;
    constexpr int kSmemSizeB           = kSmemSizeBPerStage * STAGES;
    constexpr int kSmemSizeScales      = CTA_N * STAGES / 2;
    constexpr int kSmemSizeZeros       = CTA_N * STAGES / 2;
    constexpr int scales_load_interval = G / CTA_K;
    extern __shared__ half mem_shared[];
    f16_t *A_shared      = reinterpret_cast<f16_t *>(mem_shared);
    f16_t *B_shared      = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA);
    f16_t *scales_shared = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB);
    f16_t *zeros_shared  = reinterpret_cast<f16_t *>(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
    f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
    f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
    f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
    int cta_offset_m  = blockIdx_m * CTA_M;
    int cta_offset_n  = blockIdx_n * CTA_N;
    int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
    int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;

    for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
        C_warp[i] = 0.0;

    int gemm_iters                = (K + CTA_K - 1) / CTA_K;
    int k_0_0_ld                  = 0;
    int k_0_0                     = 0;
    constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
996
#pragma unroll
Muyang Li's avatar
Muyang Li committed
997
998
999
1000
1001
    for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) {
        global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
            A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
        global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(
            B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
1002
        global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
Muyang Li's avatar
Muyang Li committed
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
            scales,
            scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
            zeros,
            zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
            N,
            cta_offset_m,
            cta_offset_n,
            k_0_0_ld,
            0,
            k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
1013
        if constexpr (STAGES > 1)
Muyang Li's avatar
Muyang Li committed
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
            __pipeline_commit();
    }
    if constexpr (STAGES > 1)
        __pipeline_wait_prior(STAGES - 2);
    __syncthreads();

    share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
        A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
    share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared,
                                                                                                scales_shared,
                                                                                                zeros_shared,
                                                                                                B_shared_warp_tmp_[0],
                                                                                                B_shared_warp_[0],
                                                                                                warp_offset_m,
                                                                                                warp_offset_n,
                                                                                                0);
    constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;

    for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) {
        int ld_stage      = k_0_0_ld % STAGES;
        int compute_stage = k_0_0 % STAGES;
        f16_t *A_shared_this_compute_stage;
        f16_t *B_shared_this_compute_stage;
        f16_t *scales_shared_this_compute_stage;
        f16_t *zeros_shared_this_compute_stage;

        for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) {
            A_shared_this_compute_stage      = A_shared + compute_stage * kSmemSizeAPerStage;
            B_shared_this_compute_stage      = B_shared + compute_stage * kSmemSizeBPerStage;
            scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N;
            zeros_shared_this_compute_stage  = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;
            share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
                A_shared_this_compute_stage,
                A_shared_warp_[(iter_k + 1) % 2],
                warp_offset_m,
                warp_offset_n,
                (iter_k + 1) % SHARED_K_ITERS);
            if ((iter_k + 1) % kInterleave == 0) {
                if (compute_stage % 2 == 1) {
                    share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage,
                        scales_shared_this_compute_stage,
                        zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[1],
                        B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m,
                        warp_offset_n,
                        (iter_k + 1) % SHARED_K_ITERS);
                } else {
                    share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage,
                        scales_shared_this_compute_stage,
                        zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[0],
                        B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m,
                        warp_offset_n,
                        (iter_k + 1) % SHARED_K_ITERS);
                }
            } else {
                if (compute_stage % 2 == 1) {
                    share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage,
                        scales_shared_this_compute_stage,
                        zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[1],
                        B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m,
                        warp_offset_n,
                        (iter_k + 1) % SHARED_K_ITERS);
                } else {
                    share_to_reg_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
                        B_shared_this_compute_stage,
                        scales_shared_this_compute_stage,
                        zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[0],
                        B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m,
                        warp_offset_n,
                        (iter_k + 1) % SHARED_K_ITERS);
                }
            }
            __syncthreads();
            f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
            f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
            for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) {
                for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) {
                    mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8,
                                 A_shared_warp + i_0_3 * 8,
                                 B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
                    mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4,
                                 A_shared_warp + i_0_3 * 8,
                                 B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
                }
            }

            if (iter_k < WARP_K / INTRIN_K - 1) {
                if constexpr (STAGES == 1)
                    __syncthreads();
                global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    A,
                    A_shared + ld_stage * kSmemSizeAPerStage,
                    M,
                    K,
                    cta_offset_m,
                    cta_offset_n,
                    k_0_0_ld,
                    iter_k,
                    k_0_0_ld < gemm_iters);
                global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    B,
                    B_shared + ld_stage * kSmemSizeBPerStage,
                    K,
                    cta_offset_m,
                    cta_offset_n,
                    k_0_0_ld,
                    iter_k,
                    k_0_0_ld < gemm_iters);
            }

            if (iter_k == WARP_K / INTRIN_K - 2) {
                if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) {
                    __syncthreads();
                }
                global_to_share_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    A,
                    A_shared + ld_stage * kSmemSizeAPerStage,
                    M,
                    K,
                    cta_offset_m,
                    cta_offset_n,
                    k_0_0_ld,
                    iter_k + 1,
                    k_0_0_ld < gemm_iters);
                global_to_share_one_stage_B_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(
                    B,
                    B_shared + ld_stage * kSmemSizeBPerStage,
                    K,
                    cta_offset_m,
                    cta_offset_n,
                    k_0_0_ld,
                    iter_k + 1,
                    k_0_0_ld < gemm_iters);
                global_to_share_one_stage_scales_T2<f16_t, CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
                    scales,
                    scales_shared + (ld_stage / scales_load_interval) * CTA_N,
                    zeros,
                    zeros_shared + (ld_stage / scales_load_interval) * CTA_N,
                    N,
                    cta_offset_m,
                    cta_offset_n,
                    k_0_0_ld,
                    iter_k,
                    k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
                if constexpr (STAGES > 1) {
                    __pipeline_commit();
                    __pipeline_wait_prior(STAGES - 2);
                }
                compute_stage = (k_0_0 + 1) % STAGES;
                __syncthreads();
            }
1175
1176
        }
    }
Muyang Li's avatar
Muyang Li committed
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
    for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) {
        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) {
            for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) {
                int write_row =
                    cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
                if (write_row < M) {
                    *reinterpret_cast<f162_t *>(C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
                                                (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
                        cuda_cast<f162_t>(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
                                                                      ax1_0_1 * 8 + local_id));
                }
            };
1189
1190
1191
1192
        }
    }
}

Muyang Li's avatar
Muyang Li committed
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
/*
 * Launches gemm_w4a16_T2 for the large-batch path of awq_gemm_forward_cuda (more than 192 rows).
 *
 * Tiling is fixed:
 *   - CTA tile 64 x 128 x 64.
 *   - Warp tile 64 x 32 x 64, which gives (CTA_M / WARP_M) * (CTA_N / WARP_N) = 4 warps.
 *   - 4 pipeline stages.
 *   - Quantization group size G = 128.
 * The grid is 1-D with ceil(M / CTA_M) * (N / CTA_N) blocks. Each block is (WARP_SIZE, NUM_WARPS) threads.
 * Dynamic shared memory per stage holds:
 *   - the A tile,
 *   - the B tile (divided by kInterleave),
 *   - CTA_N elements that the kernel splits evenly between scales and zeros.
 *
 * Params:
 *   in_feats, kernel, scales, zeros, out_feats - pointers forwarded unchanged to the kernel.
 *   num_in_feats     - M, number of activation rows (the kernel guards row writes with `row < M`).
 *   num_out_channels - N, number of output columns. Must be a multiple of CTA_N (128): the grid only covers
 *                      N / CTA_N column tiles and the kernel epilogue does not bound-check columns, so any
 *                      remainder would silently stay unwritten.
 *   num_in_channels  - K, reduction length.
 *
 * Throws std::runtime_error if N is not a multiple of 128, or if configuring the shared-memory size or the
 * kernel launch reports a CUDA error.
 */
template <typename f16_t>
static void launch_gemm_w4a16_T2_large_m(f16_t *in_feats,
                                         f16_t *kernel,
                                         f16_t *scales,
                                         f16_t *zeros,
                                         f16_t *out_feats,
                                         int num_in_feats,
                                         int num_out_channels,
                                         int num_in_channels) {
    constexpr int G      = 128;
    constexpr int CTA_M  = 64;
    constexpr int CTA_N  = 128;
    constexpr int CTA_K  = 64;
    constexpr int WARP_M = 64;
    constexpr int WARP_N = 32;
    constexpr int WARP_K = 64;
    constexpr int STAGES = 4;

    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
    constexpr int kSmemByteSize =
        (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES *
        sizeof(f16_t);
    // All inputs are compile-time constants, so enforce the budget at compile time instead of printing at
    // runtime and returning an output tensor that was never written.
    static_assert(kSmemByteSize < 99 * 1024,
                  "gemm_w4a16_T2 tile configuration exceeds the 99 KB shared-memory budget");

    if (num_out_channels % CTA_N != 0) {
        throw std::runtime_error(
            "awq_gemm_forward_cuda: number of output channels must be a multiple of 128 for batch sizes > 192");
    }

    const int m_tiles = (num_in_feats + CTA_M - 1) / CTA_M;
    const int n_tiles = num_out_channels / CTA_N;
    dim3 num_blocks(m_tiles * n_tiles);
    dim3 threads_per_block(WARP_SIZE, NUM_WARPS);

    auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
    cudaError_t err  = cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
    if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }
    kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
        in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
    // Launch-configuration errors are only reported through the runtime's last-error state.
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }
}

/*
 * AWQ W4A16 GEMM entry point.
 *
 * Multiplies the activations `_in_feats` by the packed weight `_kernel`, using the per-group `_scales` and
 * `_zeros`. Every configuration below uses group size G = 128.
 * NOTE(review): the exact packing/layout of `_kernel`, `_scales` and `_zeros` is defined by the kernels;
 * confirm it there.
 *
 * Inputs and output:
 *   - `_in_feats` must be FP16 or BF16.
 *   - `_scales`, `_zeros` and the output are accessed with that same element type.
 *   - `_kernel` is read through its int16 storage and reinterpreted.
 *   - The output has the shape of `_in_feats`, with the last dimension replaced by
 *     `_kernel.size(0) * kInterleave`.
 *
 * Tile configuration is selected by M (the leading dimensions of `_in_feats` flattened):
 *   - M <= 32, 64, 128 or 192: launched through KERNEL_LAUNCH_CODE.
 *   - Larger M: uses gemm_w4a16_T2 via launch_gemm_w4a16_T2_large_m.
 *
 * Throws:
 *   - std::runtime_error("Unsupported input type") for any other input dtype.
 *   - The large-M path may also throw (see launch_gemm_w4a16_T2_large_m).
 */
Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, Tensor _zeros) {
    auto output_shape   = _in_feats.shape.dataExtent;
    output_shape.back() = _kernel.size(0) * kInterleave;
    int num_in_feats    = _in_feats.numel() / _in_feats.size(-1);
    int num_in_channels = _in_feats.size(-1);
    // NOTE(review): `options` / `options_int` are not referenced in this function's visible code. They are kept
    // in case the partially visible KERNEL_LAUNCH_CODE macro uses them -- confirm before removing.
    auto options         = Tensor::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    auto options_int     = Tensor::TensorOptions().dtype(Tensor::INT32).device(_in_feats.device());
    Tensor _out_feats    = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
    int num_out_feats    = _out_feats.numel() / _out_feats.size(-1);
    int num_out_channels = _out_feats.size(-1);

    const bool is_fp16 = _in_feats.scalar_type() == Tensor::FP16;
    const bool is_bf16 = _in_feats.scalar_type() == Tensor::BF16;
    if (!is_fp16 && !is_bf16) {
        throw std::runtime_error("Unsupported input type");
    }

    // An empty batch has nothing to compute. Launching anyway would request a zero-sized grid, which CUDA
    // rejects as an invalid configuration.
    if (num_out_feats == 0) {
        return _out_feats;
    }

    if (is_fp16) {
        using f16_t = half;

        auto in_feats  = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
        auto kernel    = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
        auto scales    = reinterpret_cast<f16_t *>(_scales.data_ptr());
        auto zeros     = reinterpret_cast<f16_t *>(_zeros.data_ptr());
        auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());

        if (num_out_feats <= 32) {
            constexpr int G      = 128;
            constexpr int CTA_M  = 16;
            constexpr int CTA_N  = 128;
            constexpr int CTA_K  = 128;
            constexpr int WARP_M = 16;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 2;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 64) {
            constexpr int G      = 128;
            constexpr int CTA_M  = 16;
            constexpr int CTA_N  = 128;
            constexpr int CTA_K  = 128;
            constexpr int WARP_M = 16;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 3;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 128) {
            constexpr int G      = 128;
            constexpr int CTA_M  = 32;
            constexpr int CTA_N  = 128;
            constexpr int CTA_K  = 128;
            constexpr int WARP_M = 32;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 192) {
            constexpr int G      = 128;
            constexpr int CTA_M  = 64;
            constexpr int CTA_N  = 128;
            constexpr int CTA_K  = 64;
            constexpr int WARP_M = 64;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else {
            launch_gemm_w4a16_T2_large_m<f16_t>(
                in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
        }
    } else {
        using f16_t = __nv_bfloat16;

        auto in_feats  = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
        auto kernel    = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
        auto scales    = reinterpret_cast<f16_t *>(_scales.data_ptr());
        auto zeros     = reinterpret_cast<f16_t *>(_zeros.data_ptr());
        auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());

        if (num_out_feats <= 32) {
            constexpr int G      = 128;
            constexpr int CTA_M  = 16;
            constexpr int CTA_N  = 128;
            constexpr int CTA_K  = 128;
            constexpr int WARP_M = 16;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 2;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 64) {
            constexpr int G      = 128;
            constexpr int CTA_M  = 16;
            constexpr int CTA_N  = 128;
            constexpr int CTA_K  = 128;
            constexpr int WARP_M = 16;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 3;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 128) {
            constexpr int G      = 128;
            constexpr int CTA_M  = 32;
            constexpr int CTA_N  = 128;
            constexpr int CTA_K  = 128;
            constexpr int WARP_M = 32;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else if (num_out_feats <= 192) {
            constexpr int G      = 128;
            constexpr int CTA_M  = 64;
            constexpr int CTA_N  = 128;
            constexpr int CTA_K  = 64;
            constexpr int WARP_M = 64;
            constexpr int WARP_N = 32;
            constexpr int WARP_K = 64;
            constexpr int SPLITK = 1;
            constexpr int STAGES = 4;
            KERNEL_LAUNCH_CODE
        } else {
            launch_gemm_w4a16_T2_large_m<f16_t>(
                in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
        }
    }

    return _out_feats;
}