/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include
#include
#include
#include
#include
#include

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cute {

}  // namespace cute

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

// K-looped tile GEMM: for every k-slice i of the operand fragments it runs
//   cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc).
// Operands that are not flagged as already register-resident are copied from tCsA / tCsB into
// tCrA / tCrB slice by slice. The copy of slice i + 1 is issued before the MMA on slice i, so the
// load of the next slice can overlap the current MMA.
//
// Template flags (referenced in the body):
//   A_in_regs  - when true, no copy from tCsA is issued; tCrA must already hold every k-slice.
//   B_in_regs  - when true, no copy from tCsB is issued; tCrB must already hold every k-slice.
// Parameters:
//   acc                 accumulator passed to cute::gemm; mode 1 must equal mode 1 of tCrA (MMA_M)
//                       and mode 2 must equal mode 1 of tCrB (MMA_N) -- static-asserted below.
//   tCrA, tCrB          destination fragments indexed (_, MMA_M | MMA_N, MMA_K); mode 2 must agree.
//   tCsA, tCsB          source partitions copied into tCrA / tCrB (shared memory, per the naming).
//   tiled_mma           MMA object handed to cute::gemm.
//   smem_tiled_copy_A/B tiled copies that perform the tCs* -> tCr* transfers.
//   smem_thr_copy_A/B   per-thread copy slices; only used to retile_D the destination fragments.
template inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA,
                                     Tensor4 const& tCsB, TiledMma tiled_mma,
                                     TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                                     ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));  // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
    // Re-view the fragments with the layout the tiled copies write into; same storage as tCrA / tCrB.
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
    // Prologue: bring in k-slice 0 of each operand that is not already in registers.
    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        // Prefetch slice i + 1 (when one exists) before issuing the MMA on slice i.
        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Variant of gemm() for when operand A is already held in tCrA: no copy is issued for A.
// Only B is copied from tCsB into tCrB, one k-slice ahead of the MMA, and
//   cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc)
// runs for every k-slice i.
// Parameters:
//   acc                accumulator; shape constraints are the same as for gemm() (static-asserted).
//   tCrA               A fragment indexed (_, MMA_M, MMA_K); must already be fully populated.
//   tCrB               B destination fragment indexed (_, MMA_N, MMA_K).
//   tCsB               source partition copied into tCrB.
//   tiled_mma          MMA object handed to cute::gemm.
//   smem_tiled_copy_B  tiled copy performing tCsB -> tCrB.
//   smem_thr_copy_B    per-thread copy slice; only used to retile_D tCrB.
template inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB,
                                               TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                               ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));  // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
    // Re-view tCrB with the layout the tiled copy writes into; same storage as tCrB.
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
    // Prologue: bring in k-slice 0 of B.
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        // Prefetch B slice i + 1 (when one exists) before issuing the MMA on slice i.
        if (i < size<2>(tCrA) - 1) {
            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but the N most recently committed cp.async groups (cp.async.commit_group) have
// completed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
// N is emitted as an immediate ("n" constraint), so it must be a compile-time constant.
// Compiles to nothing when CUTE_ARCH_CP_ASYNC_SM80_ENABLED is not defined.
template <int N>
CUTE_HOST_DEVICE void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Predicated copy of S into D. Both are rank-3 partitions of identical shape (CPY, CPY_M, CPY_K).
//
// Parameters:
//   thr_copy     tiled copy used for every (m, k) slice.
//   S, D         source / destination partitions.
//   identity_MN  coordinate tensor; get<0>(identity_MN(0, m, 0)) is the MN coordinate of slice m,
//                compared against max_MN.
//   predicate_K  predicate tensor; slice k is copied only when predicate_K(k) is true.
//   max_MN       exclusive MN bound; only read when Is_even_MN is false.
// Template flags:
//   Is_even_MN   skip the MN bound check (every m is in bounds).
//   Is_even_K    skip the predicate_K check (every k is in bounds).
//   Clear_OOB_MN zero D(_, m, _) for out-of-bounds m instead of leaving it untouched.
//   Clear_OOB_K  zero D(_, m, k) for out-of-bounds k instead of leaving it untouched.
// Clear_OOB_MN requires Clear_OOB_K (enforced by static_assert).
template <bool Is_even_MN = true, bool Is_even_K = true, bool Clear_OOB_MN = false, bool Clear_OOB_K = true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const& S,
                            Tensor<Engine1, Layout1>& D, Tensor<Engine2, Layout2> const& identity_MN,
                            Tensor<Engine3, Layout3> const& predicate_K, const int max_MN = 0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    // Qualified on purpose. An unqualified `copy` here first finds this flash::copy
                    // by name lookup and only reaches cute::copy through ADL.
                    cute::copy(thr_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Binary max functor: returns x if x > y, else y.
template <typename T>
struct MaxOp {
    __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; }
};

// float specialization using the built-in max overload.
template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ inline float operator()(float const& x, float const& y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Binary sum functor: returns x + y.
template <typename T>
struct SumOp {
    __device__ inline T operator()(T const& x, T const& y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Butterfly all-reduce of x across each aligned group of THREADS consecutive lanes.
// At each step, lanes exchange x with lane ^ (THREADS / 2), ..., lane ^ 1 via __shfl_xor_sync and
// combine the values with op. On return, every lane of the group holds the reduced value.
// The shuffle mask is uint32_t(-1), so all 32 lanes of the warp must execute this call.
template <int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template <typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Recursion base case: a single exchange with the neighbouring lane (lane ^ 1).
template <>
struct Allreduce<2> {
    template <typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// The size-4 MMA mode is split as (2, 2). The outer factor pairs with MMA_M to form rows and
// the inner factor pairs with MMA_N to form columns. Only the layout is rebuilt; no data moves.
template <typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
// MMA_traits must expose Shape_MNK. Its K extent (8 or 16) selects the MMA_N grouping factor
// (1 or 2) folded into the first output mode. Only the layout is rebuilt; no data moves.
template <typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
    using X = Underscore;
    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
                       get<0, 1>(l),
                       get<1, 1, 1>(l));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element-wise converts `tensor` to To_type with cutlass::NumericArrayConverter and returns a
// tensor with the same layout.
// The source is reinterpreted as a flat cutlass::Array of size(tensor) elements starting at
// tensor.data(), so it must be "contiguous" in that sense. Its size must also be a compile-time
// constant.
// Returns an OWNING tensor. Returning a view of the local `frag` would leave the caller with a
// pointer to storage that ends at this function's return; that only happens to work when the
// call is fully inlined.
// NOTE(review): make_tensor<To_type>(layout) requires tensor.layout() to be static, which holds
// for register fragments -- confirm at call sites that pass other tensors.
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
    // Move the converted values into storage owned by the result. Same layout, so each element
    // sits at the same storage offset it had in `frag`, and callers that re-view data() are unaffected.
    Tensor out = make_tensor<To_type>(tensor.layout());
    cute::copy(make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()), out);
    return out;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash