// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "common.h"
#include "cta_iterator.h"
#include "warp_iterator.h"

//#include

namespace turbomind {

__inline__ __device__ void
mma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)
{
#if TURBOMIND_ARCH_SM75
    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
    float const*    C = reinterpret_cast<float const*>(&c);
    float*          D = reinterpret_cast<float*>(&d);
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
        "{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    assert(TURBOMIND_ARCH_SM75);
#endif
}

__inline__ __device__ void
mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
{
#if TURBOMIND_ARCH_SM80
    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
    float const*    C = reinterpret_cast<float const*>(&c);
    float*          D = reinterpret_cast<float*>(&d);
    asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    // fall back to two m16n8k8 MMAs on architectures without m16n8k16
    const Array<half, 4>* _a = (const Array<half, 4>*)&a;
    const Array<half, 2>* _b = (const Array<half, 2>*)&b;
    mma_m16n8k8_row_col(d, _a[0], _b[0], c);
    mma_m16n8k8_row_col(d, _a[1], _b[1], d);
#endif
}

__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
{
    int    src_lane = lane_id / 8 + lane_id % 4 * 8;
    uint   u0       = __shfl_sync(0xffffffff, value, src_lane);
    uint   u1       = __shfl_sync(0xffffffff, value, src_lane + 4);
    short2 r;

    if (lane_id % 8 < 4) {
        r.x = ((short2&)u0).x;
        r.y = ((short2&)u1).x;
    }
    else {
        r.x = ((short2&)u0).y;
        r.y = ((short2&)u1).y;
    }
    return (uint&)r;
}

#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)
{
#if TURBOMIND_ARCH_SM75
    uint d;
    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(d) : "r"(a));
    return d;
#else
    assert(TURBOMIND_ARCH_SM75);
    return 0;
#endif
}
#endif

__inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
{
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 8)
    (void)lane_id;
    return transpose_m8n8_b16_movmatrix(a);
#else
    return transpose_m8n8_b16_warp_shuffle(a, lane_id);
#endif
}

namespace ops {

__inline__ __device__ float4 operator+(const float4& a, const float4& b)
{
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

__inline__ __device__ float2 operator+(const float2& a, const float2& b)
{
    return {a.x + b.x, a.y + b.y};
}

}  // namespace ops
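
// The Gemm class below organizes the computation as a three-level tile hierarchy:
//   * each thread block (CTA) computes a CTA_M x CTA_N tile of C, marching over K
//     in steps of CTA_K;
//   * the CTA is split into kWarpCountM x kWarpCountN x kWarpCountK warps, each
//     owning a WARP_M x WARP_N x WARP_K sub-tile; warps with different k indices
//     form "sliced-k" partial sums that are reduced in sum_slices();
//   * each warp issues OP_M x OP_N x OP_K (16x8x16) tensor-core MMAs through
//     mma_m16n8k16_row_col() above.
// A carries packed 4-bit data (as uint), B the fp16 operand, and Q appears to hold
// per-group (scale, zero-point) half2 pairs used to dequantize A while loading
// fragments; the loading details live in cta_iterator.h / warp_iterator.h.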
template<typename OutputOps, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES>
struct Gemm {

    static constexpr int kWarpCountM  = CTA_M / WARP_M;
    static constexpr int kWarpCountN  = CTA_N / WARP_N;
    static constexpr int kWarpCountK  = CTA_K / WARP_K;
    static constexpr int kWarpCountMN = kWarpCountM * kWarpCountN;
    static constexpr int kWarpCount   = kWarpCountMN * kWarpCountK;

    static constexpr int SLICES  = kWarpCountK;
    static constexpr int SLICE_K = CTA_K / SLICES;

    static_assert(SLICE_K % WARP_K == 0, "infeasible sliced-k setting");

    using IteratorA = turbomind::IteratorA;
    using IteratorQ = turbomind::IteratorQ;
    using IteratorB = turbomind::IteratorB;

    static constexpr int OP_M = 16;
    static constexpr int OP_N = 8;
    static constexpr int OP_K = 16;

    using WarpIterA = turbomind::WarpIteratorA;
    using WarpIterB = turbomind::WarpIteratorB;
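
    // warp_mma() executes one CTA_K-wide slice of the main loop. For every k-step it
    //   * loads the warp fragments for the *next* k-step into the other half of the
    //     double buffers warp_frag_A_/warp_frag_B_, overlapping shared-memory loads
    //     with the tensor-core MMAs of the current step;
    //   * issues the ITER_M x ITER_N grid of m16n8k16 MMAs for the current step;
    //   * drip-feeds asynchronous prefetches of the next global-memory stage
    //     (kBatchA/kBatchQ/kBatchB requests per step), and at k-step ITER_K-2
    //     commits the pipeline, waits until at most STAGES-2 stages are in flight,
    //     and advances all iterators to the next stage.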
    __device__ void warp_mma(IteratorA& iter_A,
                             IteratorQ& iter_Q,
                             IteratorB& iter_B,
                             WarpIterA& warp_iter_A,
                             WarpIterB& warp_iter_B,
                             float*     accum,
                             int        slice_id,
                             int&       gemm_iter)
    {
        constexpr int ITER_M = WARP_M / OP_M;
        constexpr int ITER_N = WARP_N / OP_N;
        constexpr int ITER_K = WARP_K / OP_K;

        constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K;
        constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K;
        constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K;

        auto frag_C_ptr = (Array<float, 4>*)accum;  // [ITER_N, ITER_M]

        PRAGMA_UNROLL
        for (int iter_k = 0; iter_k < ITER_K; ++iter_k) {

            warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
            warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);

            auto warp_frag_A = warp_frag_A_[iter_k % 2];
            auto warp_frag_B = warp_frag_B_[iter_k % 2];

            PRAGMA_UNROLL
            for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
                PRAGMA_UNROLL
                for (int iter_n = 0; iter_n < ITER_N; ++iter_n) {
                    auto& frag_A = warp_frag_A[iter_m];
                    auto& frag_B = warp_frag_B[iter_n];
                    auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m];
                    mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C);
                }
            }

            if (iter_k < ITER_K - 1) {
                iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0);
                iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0);
                iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0);
            }

            if (iter_k == ITER_K - 2) {
                iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0);
                iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0);
                iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0);

                __pipeline_commit();
                __pipeline_wait_prior(STAGES - 2);
                sync_slice(slice_id);

                iter_A.next_stage();
                iter_Q.next_stage();
                iter_B.next_stage();

                warp_iter_A.next_stage();
                warp_iter_B.next_stage();

                --gemm_iter;
            }
        }
    }

    template<typename T, int N>
    __device__ static void copy(T (&dst)[N], const T (&src)[N])
    {
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            dst[i] = src[i];
        }
    }

    template<typename T, int N>
    __device__ static void clear(T (&dst)[N])
    {
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            dst[i] = T{};
        }
    }

    __device__ void sync_slice(int slice_id)
    {
        if constexpr (SLICES == 1) {
            __syncthreads();
        }
        else {
            constexpr int      SLICE_GROUP = (SLICES + 7) / 8;
            constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
            const uint32_t     barrier_id  = slice_id / SLICE_GROUP + 1;

            asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
        }
    }

    __device__ void load_partial(float* tb_frag_C, const float* partial_C, int cta, int slice_id)
    {
        if (slice_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < CTA_N; ++i) {
                tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x];
            }
        }
    }

    __device__ void store_partial(float* partial_C, const float* tb_frag_C, int cta, int slice_id)
    {
        if (slice_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < CTA_N; ++i) {
                partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i];
            }
        }
    }

    template<int OUTPUT_OP>
    __device__ void store_accum(float* tb_frag_C,
                                float* tb_smem_C,
                                half*  C,
                                int    m,
                                int    n,
                                int    cta_m,
                                int    cta_n,
                                int    warp_id_m,
                                int    warp_id_n,
                                int    lane_id,
                                int    slice_id)
    {
        if (slice_id != 0) {
            return;
        }

        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
        PRAGMA_UNROLL
        for (int i = 0; i < WARP_N / OP_N; ++i) {
            const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4];
            const int     nn     = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4;
            PRAGMA_UNROLL
            for (int j = 0; j < WARP_M / OP_M; ++j) {
                PRAGMA_UNROLL
                for (int x = 0; x < 2; ++x) {
                    const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2;
                    // convert to half
                    half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
                    // transpose 8x8 accum tile
                    uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
                    // store to global memory
                    OutputOps::template apply<OUTPUT_OP>(trans_C, mm, nn, C, m, n);
                }
            }
        }
    }

    __device__ void sum_slices(float* tb_frag_C, float* tb_smem_C, int warp_id_m, int warp_id_n, int lane_id, int slice_id)
    {
        int offset_m = warp_id_m * WARP_M / OP_M;
        int offset_n = warp_id_n * WARP_N / OP_N;

        PRAGMA_UNROLL
        for (int z = 0; z < SLICES; ++z) {
            if (slice_id == z) {
                PRAGMA_UNROLL
                for (int i = 0; i < WARP_N / OP_N; ++i) {
                    PRAGMA_UNROLL
                    for (int j = 0; j < WARP_M / OP_M; ++j) {
                        PRAGMA_UNROLL
                        for (int x = 0; x < 4; ++x) {
                            int src = (i * WARP_M / OP_M + j) * 4 + x;
                            int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
                            if (z > 0) {
                                using namespace ops;
                                tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src];
                            }
                            tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src];
                        }
                    }
                }
            }
            __syncthreads();
        }

        if (slice_id == 0) {
            PRAGMA_UNROLL
            for (int i = 0; i < WARP_N / OP_N; ++i) {
                PRAGMA_UNROLL
                for (int j = 0; j < WARP_M / OP_M; ++j) {
                    PRAGMA_UNROLL
                    for (int x = 0; x < 4; ++x) {
                        int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
                        int dst = (i * WARP_M / OP_M + j) * 4 + x;
                        tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id];
                    }
                }
            }
        }
    }

    Array<half, 8> warp_frag_A_[2][WARP_M / OP_M];
    Array<half, 4> warp_frag_B_[2][WARP_N / OP_N];
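
    // run_v2() is the kernel body: it decomposes threadIdx into (warp_id_m, warp_id_n,
    // warp_id_k), gives every k-slice its own shared-memory partition for A and B,
    // prefetches STAGES-1 stages before the main loop, then calls warp_mma() until all
    // K iterations (plus the pipeline drain) are done. When SLICES > 1 the per-slice
    // partial sums are reduced through shared memory (sum_slices) before the fp32
    // accumulators are converted to fp16 and written out through the output operator
    // selected by output_op_idx.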
    __device__ void run_v2(half* __restrict__ C,
                           const uint* __restrict__ A,
                           const half* __restrict__ B,
                           const half2* __restrict__ Q,
                           int M,
                           int N,
                           int K,
                           int output_op_idx)
    {
        static_assert(WARP_M % OP_N == 0);

        float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4];

        extern __shared__ uint8_t smem[];

        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        const int warp_id_m  = warp_id % kWarpCountM;
        const int warp_id_nk = warp_id / kWarpCountM;
        const int warp_id_n  = warp_id_nk % kWarpCountN;
        const int warp_id_k  = warp_id_nk / kWarpCountN;

        const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m;

        const int slice_id = warp_id_k;

        const int cta_k = slice_id * SLICE_K;  // sliced-k offset
        const int cta_m = blockIdx.x * CTA_M;
        const int cta_n = blockIdx.y * CTA_N;

        // each slice has its own partition of smem
        uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id);
        half* const  tb_smem_B = (half*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id);

        // [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA
        float* const tb_smem_C = (float*)smem;

        __shared__ typename IteratorQ::Storage tb_smem_Q_storage;
        auto tb_smem_Q = tb_smem_Q_storage.data[slice_id];

        IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id};
        IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id};
        IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id};

        const int offset_m = warp_id_m * WARP_M + lane_id;

        WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k);
        WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0);

        int gemm_iter = (K + CTA_K - 1) / CTA_K;

        PRAGMA_UNROLL
        for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) {
            iter_A.prefetch_stage(gemm_iter > 0);
            iter_Q.prefetch_stage(gemm_iter > 0);
            iter_B.prefetch_stage(gemm_iter > 0);
            __pipeline_commit();
        }

        clear(tb_frag_C);

        __pipeline_wait_prior(STAGES - 2);
        sync_slice(slice_id);

        warp_iter_A.load(warp_frag_A_[0], 0);
        warp_iter_B.load(warp_frag_B_[0], 0);

        PRAGMA_NO_UNROLL
        for (; gemm_iter > -STAGES + 1;) {
            warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter);
        }

        __pipeline_commit();
        __pipeline_wait_prior(0);
        __syncthreads();

        if constexpr (SLICES > 1) {
            sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id);
        }

        switch (output_op_idx) {
            case 0:
                store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
                break;
            case 1:
                store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
                break;
            default:
                return;
        }
    }
};

template<typename OutputOps, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES>
__global__ void gemm_s4_f16_nn(half* __restrict__ C,
                               const uint* __restrict__ A,
                               const half* __restrict__ B,
                               const half2* __restrict__ Q,
                               int M,
                               int N,
                               int K,
                               int output_op_idx)
{
    Gemm<OutputOps, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES>{}.run_v2(C, A, B, Q, M, N, K, output_op_idx);
}

}  // namespace turbomind
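
// Illustrative launch sketch (not part of the original header). The tile sizes, the
// OutputOps type ("MyOutputOps") and the shared-memory sizing below are assumptions
// for the sake of the example; the real dispatch logic lives in the accompanying
// .cu sources.
//
//   constexpr int CTA_M = 128, CTA_N = 128, CTA_K = 32;
//   constexpr int WARP_M = 64, WARP_N = 64, WARP_K = 32;  // 2x2 warps, no k-slicing
//   constexpr int STAGES = 3;
//   using GemmType = Gemm<MyOutputOps, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES>;
//
//   dim3 grid((M + CTA_M - 1) / CTA_M, (N + CTA_N - 1) / CTA_N);
//   dim3 block(GemmType::kWarpCount * WARP_SIZE);
//   // dynamic smem must cover the per-slice A/B stage buffers used in run_v2();
//   // when SLICES > 1 the accumulator exchange buffer (tb_smem_C) reuses the same space
//   size_t smem_size = GemmType::SLICES
//                      * (GemmType::IteratorA::kSmemByteSize + GemmType::IteratorB::kSmemByteSize);
//   gemm_s4_f16_nn<MyOutputOps, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES>
//       <<<grid, block, smem_size>>>(C, A, B, Q, M, N, K, /*output_op_idx=*/0);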