"vscode:/vscode.git/clone" did not exist on "82703dff7395dbc80d320af2d101d3ea530d2a25"
Commit c28eca99 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Reorganize files and add sparse prefill/decoding kernels on hopper

parent 261330bb
#include "common/mask.cuh"
#include "common/utils.hpp"
#include "fmha_cutlass_fwd_sm100.cuh"
#include "interface.h"
#include <Python.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_bf16.h>
#include <torch/library.h>
#include "common/mask.cuh"
#include "common/utils.hpp"
#include "fmha_cutlass_fwd_sm100.cuh"
template <class Mask, class Varlen, class Element, class ElementOut, class Mla>
void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen,
......
#include <torch/python.h>
#pragma once
#include <ATen/Tensor.h>
void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
......@@ -10,8 +12,3 @@ void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Ten
at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv,
at::Tensor dq, at::Tensor dk, at::Tensor dv,
int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fwd", &FMHACutlassSM100FwdRun);
m.def("bwd", &FMHACutlassSM100BwdRun);
}
......@@ -34,6 +34,7 @@
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "utils.h" // for IS_SM100
namespace cutlass::fmha::kernel {
......@@ -138,6 +139,7 @@ struct FmhaKernelBwdConvert {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if IS_SM100
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape));
}
......@@ -147,6 +149,11 @@ struct FmhaKernelBwdConvert {
if (params.ptr_src_dV != nullptr) {
copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape));
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
};
......
......@@ -34,6 +34,7 @@
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "utils.h" // for IS_SM100
namespace cutlass::fmha::kernel {
......@@ -104,6 +105,7 @@ struct FmhaKernelBwdSumOdO {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if IS_SM100
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
......@@ -155,6 +157,11 @@ struct FmhaKernelBwdSumOdO {
}
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
};
......
......@@ -41,7 +41,8 @@
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "collective/fmha_common.hpp"
#include "utils.h" // for IS_SM100
#include "../collective/fmha_common.hpp"
#include <cmath>
......@@ -1499,6 +1500,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if IS_SM100
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
......@@ -1823,6 +1825,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
/* no-op */
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
static dim3 get_block_shape() {
......
......@@ -41,7 +41,8 @@
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "collective/fmha_common.hpp"
#include "utils.h" // for IS_SM100
#include "../collective/fmha_common.hpp"
#include <cmath>
......@@ -1492,6 +1493,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if IS_SM100
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
......@@ -1816,6 +1818,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
/* no-op */
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
static dim3 get_block_shape() {
......
......@@ -37,11 +37,12 @@
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/arch/tmem_allocator_sm100.hpp"
#include "kernel/fmha_options.hpp"
#include "kernel/fmha_tile_scheduler.hpp"
#include "kernel/fmha_causal_tile_scheduler.hpp"
#include "collective/fmha_fusion.hpp"
#include "collective/fmha_common.hpp"
#include "utils.h" // for IS_SM100
#include "../kernel/fmha_options.hpp"
#include "../kernel/fmha_tile_scheduler.hpp"
#include "../kernel/fmha_causal_tile_scheduler.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/fmha_common.hpp"
namespace cutlass::fmha::kernel {
......@@ -251,6 +252,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if IS_SM100
TileScheduler tile_scheduler{params.tile_scheduler};
......@@ -629,6 +631,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
/* no-op, donate regs and exit */
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n");
}
#endif
}
};
......
......@@ -8,6 +8,4 @@ static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5;
}
#include <cutlass/cutlass.h>
#include "params.h"
#include "utils.h"
#include "params.h"
#include "config.h"
#include "traits.h"
using namespace cute;
using cutlass::arch::NamedBarrier;
namespace sm90 {
// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking
// The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2)
// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM
static constexpr float MAX_INIT_VAL_SM = -1e30f;
static constexpr float MAX_INIT_VAL = -1e33f;
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
......@@ -756,7 +758,7 @@ __forceinline__ __device__ void wg0_subroutine(
TMABarrier barriers_K1[9],
bool &cur_phase_K0,
const TMAParams &tma_params,
const Flash_fwd_mla_params &params,
const DecodingParams &params,
int* block_table_ptr,
int seqlen_k,
int block_idx,
......@@ -868,7 +870,7 @@ __forceinline__ __device__ void wg1_subroutine(
TMABarrier barriers_K1[9],
bool &cur_phase_K1,
const TMAParams &tma_params,
const Flash_fwd_mla_params &params,
const DecodingParams &params,
int* block_table_ptr,
int seqlen_k,
int block_idx,
......@@ -943,7 +945,7 @@ __forceinline__ __device__ void wg1_subroutine(
}
// A helper function for determining the length of the causal mask for one q token
__forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params &params, int m_block_idx, int local_seq_q_idx) {
__forceinline__ __device__ int get_mask_len(const DecodingParams &params, int m_block_idx, int local_seq_q_idx) {
int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx;
if (global_seq_q_idx < params.q_seq_per_hk) {
int s_q_idx = global_seq_q_idx / params.q_head_per_hk;
......@@ -956,7 +958,7 @@ __forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params &params,
template<typename T, typename TmaParams>
__global__ void __launch_bounds__(T::NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params, __grid_constant__ const TmaParams tma_params) {
flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) {
// grid shape: [
// num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))),
// num_kv_heads,
......@@ -966,6 +968,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
// If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx])
// For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file).
#if IS_SM90
const int m_block_idx = blockIdx.x;
const int k_head_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
......@@ -1018,11 +1021,11 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
// We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race.
int4 tile_scheduler_metadata = *(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int4 tile_scheduler_metadata = *(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int begin_seqlen = tile_scheduler_metadata.y;
int sched_begin_block_idx = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int end_seqlen = tile_scheduler_metadata.w;
int sched_end_block_idx = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);
......@@ -1034,9 +1037,9 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0;
int end_block_idx = batch_idx == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN);
const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0;
int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = __ldg(params.num_splits_ptr + batch_idx + 1) - __ldg(params.num_splits_ptr + batch_idx) == 1;
int rRightBorderForQSeq[2];
if (params.is_causal) {
......@@ -1057,7 +1060,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
// Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks
// NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling
int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1);
end_block_idx = batch_idx == end_idx ? cute::ceil_div(min(end_seqlen, seqlen_k-common_mask_len), kBlockN) : cute::ceil_div(seqlen_k-common_mask_len, kBlockN);
int last_block_in_seq = cute::ceil_div(seqlen_k-common_mask_len, kBlockN);
end_block_idx = batch_idx == end_idx ? min(sched_end_block_idx, last_block_in_seq) : last_block_in_seq;
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
......@@ -1267,11 +1271,16 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
if (batch_idx != end_idx)
__syncthreads();
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
}
#endif
}
template<typename InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
void run_flash_splitkv_mla_kernel(DecodingParams &params, cudaStream_t stream) {
using T = Traits<InputT>;
auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);
auto tma_Q = cute::make_tma_copy(
......@@ -1347,8 +1356,10 @@ void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t str
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(DecodingParams &params, cudaStream_t stream);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_splitkv_mla_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_flash_splitkv_mla_kernel<cutlass::half_t>(DecodingParams &params, cudaStream_t stream);
#endif
}
......@@ -2,5 +2,9 @@
#include "params.h"
namespace sm90 {
template<typename InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);
void run_flash_splitkv_mla_kernel(DecodingParams &params, cudaStream_t stream);
}
#pragma once
#include <cutlass/numeric_types.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
using bf16 = cutlass::bfloat16_t;
using fp8 = cutlass::float_e4m3_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
using namespace cute;
static constexpr int NUM_THREADS = 128*3;
static constexpr int BLOCK_M = 64;
static constexpr int TOPK_BLOCK_SIZE = 64;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int QUANT_TILE_SIZE = 128;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
static constexpr int HEAD_DIM_NOPE = HEAD_DIM_V;
static constexpr int HEAD_DIM_ROPE = HEAD_DIM_K - HEAD_DIM_V;
static constexpr int NUM_SCALES = HEAD_DIM_NOPE / QUANT_TILE_SIZE;
static constexpr int NUM_BYTES_PER_TOKEN = HEAD_DIM_NOPE + NUM_SCALES*sizeof(float) + HEAD_DIM_ROPE*sizeof(bf16);
static constexpr int NUM_K_BUFS = 2;
using SmemLayoutQTile = decltype(tile_to_shape(
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},
Shape<Int<BLOCK_M>, Int<64>>{}
));
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(tile_to_shape(
SmemLayoutQTile{},
Shape<Int<BLOCK_M>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
));
using SmemLayoutQ = SmemLayoutQTiles<9>;
using SmemLayoutKTile = decltype(tile_to_shape(
GMMA::Layout_INTER_Atom<bf16, GMMA::Major::K>{},
Shape<Int<TOPK_BLOCK_SIZE>, _64>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(tile_to_shape(
SmemLayoutKTile{},
Shape<Int<TOPK_BLOCK_SIZE>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<Shape<Int<64*NUM_TILES>, Int<TOPK_BLOCK_SIZE>>, Stride<Int<TOPK_BLOCK_SIZE>, _1>>{}
));
using SmemLayoutOBuf = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{}
));
using SmemLayoutOAccumBuf = Layout<
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
using SmemLayoutK = SmemLayoutKTiles<9>;
using SmemLayoutV = SmemLayoutKTilesTransposed<8>;
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>;
using SmemLayoutS = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}
));
struct SharedMemoryPlan {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
union {
array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
} u;
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M];
transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
};
template<
typename Shape_Q, typename TMA_Q,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
Shape_O shape_O; TMA_O tma_O;
};
using TiledMMA_QK = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
#pragma once
#include <cuda_fp8.h>
#include <cuda_bf16.h>
struct fp8x8 {
__nv_fp8x4_e4m3 lo;
__nv_fp8x4_e4m3 hi;
};
struct fp8x16 {
fp8x8 lo;
fp8x8 hi;
};
struct bf16x8 {
__nv_bfloat162 a, b, c, d;
};
__device__ __forceinline__
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) {
__nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale);
#define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \
{ \
float4 fp32x4 = (float4)(FP8x4); \
OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \
OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \
}
bf16x8 result;
DEQUANT_FP8x4(result.a, result.b, inputs.lo);
DEQUANT_FP8x4(result.c, result.d, inputs.hi);
return result;
}
enum class L1CacheHint {
NO_ALLOCATE,
EVICT_FIRST,
EVICT_NORMAL,
EVICT_LAST
};
enum class L2PrefetchHint {
B64,
B128,
B256
};
template<
typename T,
L1CacheHint l1_cache_hint,
L2PrefetchHint l2_prefetch_hint
>
__device__ __forceinline__
T load_128b_from_gmem(const void* addr) {
static_assert(sizeof(T) == 128/8);
int4 ret;
#define EXEC(L1_HINT_STR, L2_HINT_STR) { \
asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v4.s32 {%0, %1, %2, %3}, [%4];" \
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) \
: "l"(addr)); \
}
#define DISPATCH_L2(L1_HINT_STR) { \
if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \
EXEC(L1_HINT_STR, "64B") \
else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \
EXEC(L1_HINT_STR, "128B") \
else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \
EXEC(L1_HINT_STR, "256B") \
}
if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)
DISPATCH_L2("no_allocate")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)
DISPATCH_L2("evict_first")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)
DISPATCH_L2("evict_normal")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)
DISPATCH_L2("evict_last")
#undef EXEC
#undef DISPATCH_L2
return *reinterpret_cast<T*>(&ret);
}
#pragma once
#include "named_barriers.h"
// Store O / OAccum
template<
bool IS_NO_SPLIT,
typename TMAParams,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3
>
__forceinline__ __device__ void store_o(
Tensor0 &rO, // ((2, 2, 32), 1, 1)
Tensor1 &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V)
Tensor2 &sOutputBuf,
Tensor3 &sOutputAccumBuf,
float rL[2],
TMAParams &tma_params,
int batch_idx,
int s_q_idx,
int head_block_idx,
int num_valid_seq_q,
int warpgroup_idx,
int idx_in_warpgroup
) {
using cutlass::arch::NamedBarrier;
if constexpr (IS_NO_SPLIT) {
// Should convert the output to bfloat16 / float16, and save it to O
Tensor rOb = make_tensor_like<bf16>(rO);
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); ++idx) {
rOb(idx) = (bf16)(rO(idx) / rL[idx%4 >= 2]);
}
Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));
TiledCopy r2s_tiled_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_PV_LocalP{}
);
ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup);
Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb);
Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf);
cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf);
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
if (threadIdx.x == 0) {
Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
Tensor my_tma_gO = flat_divide(tma_gO, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{})(_, _, head_block_idx, _0{});
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sOutputBuf),
thr_tma.partition_D(my_tma_gO)
);
cute::tma_store_arrive();
}
} else {
// Should save the result to OAccum
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); idx += 2) {
int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);
int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;
*(float2*)(&(sOutputAccumBuf(row, col))) = float2 {
rO(idx) / rL[idx%4 >= 2],
rO(idx+1) / rL[idx%4 >= 2],
};
}
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
if (elect_one_sync()) {
CUTLASS_PRAGMA_UNROLL
for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) {
int row = local_row * (256/32) + (threadIdx.x / 32);
if (row < num_valid_seq_q) {
SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float));
}
}
cute::tma_store_arrive();
}
}
}
#pragma once
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx;
}
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
template<
typename TMA,
typename Tensor0,
typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
const TMA &tma_copy,
const Tensor0 &src,
Tensor1 &dst,
transac_bar_t &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL,
const uint16_t &multicast_mask = 0
) {
auto thr_tma = tma_copy.get_slice(_0{});
cute::copy(
tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), multicast_mask, cache_hint),
thr_tma.partition_S(src),
thr_tma.partition_D(dst)
);
}
template<typename T>
CUTE_DEVICE
static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {
long2 data_long2 = *reinterpret_cast<const long2*>(&data);
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
asm volatile (
"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
:
: "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr)
);
}
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment