Commit 3524e13c authored by Tri Dao

Update to Cutlass 3.1

parent 364a5b4a
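Editor's note on the recurring pattern below: with the move to CUTLASS 3.1, `cute::copy` (and flash-attention's `flash::copy` / `flash::gemm` wrappers) are handed the TiledCopy object itself, which carries the copy atom and vectorization info, while per-thread partitioning still goes through `get_thread_slice`. So every temporary `TiledCopy{}.get_thread_slice(tidx)` becomes a named tiled-copy object plus a thread slice, unqualified `copy(...)` calls become `cute::copy(...)`, and `flash::gemm` now receives both the tiled copies and the thread slices. A minimal sketch of the before/after, reusing names from the first hunk:

```cpp
// Sketch of the refactor applied throughout this diff.
// Old (pre-3.1): slice immediately and pass the thread slice to copy().
//     auto gmem_thr_copy_dO = typename Kernel_traits::GmemTiledCopydO{}.get_thread_slice(tidx);
//     copy(gmem_thr_copy_dO, tdOgdO, tdOrdO);
// New (CUTLASS 3.1): keep the TiledCopy as a named object; partition through
// the thread slice, but hand the TiledCopy itself to cute::copy.
typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);   // per-thread source view
Tensor tdOrdO = make_fragment_like(tdOgdO);          // matching register fragment
cute::copy(gmem_tiled_copy_dO, tdOgdO, tdOrdO);      // pass the TiledCopy, not the slice
```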
-Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933
+Subproject commit 6f47420213f757831fae65c686aa471749fa8d60

@@ -147,14 +147,16 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                 Shape<Int<kBlockM>>{}, Stride<_1>{});
-    auto gmem_thr_copy_dO = typename Kernel_traits::GmemTiledCopydO{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
+    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
     // TODO: careful, we're zeroing out dQaccum with type float4, but when
     // we do atomicAdds, we use type float. The layouts are different. Check this.
-    auto gmem_thr_copy_dQ_accum = typename Kernel_traits::GmemTiledCopydQaccum{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
+    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
     Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
     Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
-    Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_D(gdQaccum);
+    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
     Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
@@ -168,10 +170,10 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     Tensor tdOrdO = make_fragment_like(tdOgdO);
     Tensor tdOrO = make_fragment_like(tdOgO);
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
-        gmem_thr_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
     );
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
-        gmem_thr_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
     );
     // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
     // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
@@ -181,7 +183,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     if (Clear_dQaccum) {
         Tensor zero = make_fragment_like(tdQgdQaccum);
         clear(zero);
-        copy(gmem_thr_copy_dQ_accum, zero, tdQgdQaccum);
+        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
     }
 }
@@ -213,13 +215,14 @@ inline __device__ void clear_dKVaccum(const Params &params) {
     Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                   Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
-    auto gmem_thr_copy_dKV_accum = typename Kernel_traits::GmemTiledCopydQaccum{}.get_thread_slice(tidx);
-    Tensor tdKgdKaccum = gmem_thr_copy_dKV_accum.partition_D(gdKaccum);
-    Tensor tdVgdVaccum = gmem_thr_copy_dKV_accum.partition_D(gdVaccum);
+    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
+    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
+    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
+    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
     Tensor zero = make_fragment_like(tdKgdKaccum);
     clear(zero);
-    copy(gmem_thr_copy_dKV_accum, zero, tdKgdKaccum);
-    copy(gmem_thr_copy_dKV_accum, zero, tdVgdVaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
 }
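Both accumulator-clearing paths above use the same CuTe idiom: build a register fragment shaped like the per-thread gmem partition, zero it, and stream it out with the tiled copy. A condensed, commented sketch (assuming the partitions set up in `clear_dKVaccum`; note the single `zero` fragment is reused for both dK and dV, since their partitions have the same shape):

```cpp
// Zero-fill idiom used by compute_dot_do_o (Clear_dQaccum) and clear_dKVaccum:
Tensor zero = make_fragment_like(tdKgdKaccum);            // registers, same per-thread shape
clear(zero);                                              // set every element to 0
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);  // vectorized stores to gmem (dK)
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);  // same fragment reused for dV
```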
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -264,22 +267,25 @@ inline __device__ void convert_dQ(const Params &params) {
     Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdQ{});
-    auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_dQ_accum = typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
+    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
+    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
     typename Kernel_traits::TiledMmadQ tiled_mma_dq;
-    auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx);
+    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
+    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
     Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
     Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
-    Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_S(gdQaccum);
+    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
     Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
     CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
     Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
-    copy(gmem_thr_copy_dQ_accum, tdQgdQaccum, tdQrdQaccum);
+    cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
     #pragma unroll
     for (int i = 0; i < size(acc_dq); ++i) {
         acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout;
@@ -287,10 +293,10 @@ inline __device__ void convert_dQ(const Params &params) {
     // Convert acc_dq from fp32 to fp16
     Tensor rdQ = flash::convert_type<Element>(acc_dq);
     Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
-    copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ);
+    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
     __syncthreads();
     Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
-    copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ);
+    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
     Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
@@ -299,7 +305,7 @@ inline __device__ void convert_dQ(const Params &params) {
     for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }
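The tail of `convert_dQ` shows the predicated-write pattern that every epilogue in this file uses: an identity tensor maps each tile coordinate to its logical (row, column) position, the column coordinate is compared against `params.d` to build a per-k predicate, and `flash::copy` is instantiated with `Clear_OOB_K=false` so out-of-range columns are skipped rather than zeroed. An annotated excerpt (the `tdQpdQ` allocation is reconstructed from the analogous dKV code later in this diff, so treat it as a sketch):

```cpp
// cdQ(i, j) == (i, j): identity tensor mapping tile coords to logical coords.
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);               // this thread's coords
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQcdQ)));  // one predicate per k-tile
#pragma unroll
for (int k = 0; k < size(tdQpdQ); ++k) {
    // Predicate on the head-dim (column) coordinate: true iff in range.
    tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d;
}
```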
@@ -354,11 +360,14 @@ inline __device__ void convert_dKV(const Params &params) {
                              typename Kernel_traits::SmemLayoutdKV{});
     Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
-    auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_dKV_accum = typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
+    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
+    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
     typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
-    auto smem_thr_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv).get_thread_slice(tidx);
+    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
+    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
     Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
     Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
@@ -366,8 +375,8 @@ inline __device__ void convert_dKV(const Params &params) {
     Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
     Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
-    Tensor tdKgdKaccum = gmem_thr_copy_dKV_accum.partition_S(gdKaccum);
-    Tensor tdVgdVaccum = gmem_thr_copy_dKV_accum.partition_S(gdVaccum);
+    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
+    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
     Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
     Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
@@ -376,8 +385,8 @@ inline __device__ void convert_dKV(const Params &params) {
     Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
     Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
-    copy(gmem_thr_copy_dKV_accum, tdKgdKaccum, tdKrdKaccum);
-    copy(gmem_thr_copy_dKV_accum, tdVgdVaccum, tdVrdVaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
     #pragma unroll
     for (int i = 0; i < size(acc_dk); ++i) {
         acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
@@ -391,13 +400,13 @@ inline __device__ void convert_dKV(const Params &params) {
     Tensor rdV = flash::convert_type<Element>(acc_dv);
     Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
     Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
-    copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK);
-    copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV);
+    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
+    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
     __syncthreads();
     Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
     Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
-    copy(gmem_thr_copy_dKV, tdKsdK, tdKrdK);
-    copy(gmem_thr_copy_dKV, tdVsdV, tdVrdV);
+    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
+    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
     Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
@@ -406,10 +415,10 @@ inline __device__ void convert_dKV(const Params &params) {
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
 }
@@ -511,20 +520,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
                                 Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});
-    auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
     using GmemTiledCopydO = std::conditional_t<
         Is_first,
         typename Kernel_traits::GmemTiledCopydO,
         typename Kernel_traits::GmemTiledCopyQKV
     >;
-    auto gmem_thr_copy_dO = GmemTiledCopydO{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx);
+    GmemTiledCopydO gmem_tiled_copy_dO;
+    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
+    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
     using GmemLayoutAtomdQaccum = std::conditional_t<
         !Seq_parallel,
         typename Kernel_traits::GmemTiledCopydQaccum,
         typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd
     >;
-    auto gmem_thr_copy_dQ_accum = GmemLayoutAtomdQaccum{}.get_thread_slice(tidx);
+    GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
+    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -537,7 +550,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
     Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
-    Tensor tdQgdQaccum = gmem_thr_copy_dQ_accum.partition_D(gdQaccum);
+    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
     // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); }
     // __syncthreads();
     // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) {
@@ -570,12 +583,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Copy Atom retiling
     //
-    auto smem_thr_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
+    auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
+    auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
     Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
     Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
     // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
-    auto smem_thr_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
+    auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
+    auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
     // if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); }
     // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
@@ -584,7 +599,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Partition sP and sdS to match the accumulator partitioning
     // This has to be tiled_mma_sdp, not tiled_mma_dkv
     // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
-    auto smem_thr_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
+    auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
+    auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
     Tensor tPsP = smem_thr_copy_PdS.partition_D(sP);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
     // if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); }
     // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); }
@@ -593,21 +609,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // }
     Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
-    auto smem_thr_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx);
+    auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
+    auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
     Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
     Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
-    auto smem_thr_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx);
+    auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
+    auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
     Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
     Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);
-    auto smem_thr_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq).get_thread_slice(tidx);
+    auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);
+    auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);
     Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);
-    auto smem_thr_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq).get_thread_slice(tidx);
+    auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);
+    auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
     Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
-    auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx);
+    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
+    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
     Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
     //
@@ -655,7 +676,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                              Shape<Int<kBlockN>, Int<kHeadDim>>{},
                              make_stride(params.dv_row_stride, _1{}));
-    auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydKV{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
+    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
     Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
     Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
     Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
@@ -669,10 +691,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     return;
 }
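On this early-return path, `Clear_OOB_MN` and `Clear_OOB_K` are both false on the store side, so `flash::copy` leaves out-of-bounds gmem untouched instead of writing zeros, as the comment above explains. A behavioral sketch of what those template flags select (this is not the real `flash::copy` implementation; `row_ok`/`col_ok` are hypothetical stand-ins for the identity-tensor predicates):

```cpp
// Behavioral sketch only: how the Clear_OOB_* flags shape the copy loop.
for (int m = 0; m < size<1>(dst); ++m) {
    if (Is_even_MN || row_ok(m)) {            // row (seqlen) predicate
        for (int k = 0; k < size<2>(dst); ++k) {
            if (Is_even_K || col_ok(k)) {     // column (head-dim) predicate
                cute::copy(tiled_copy, src(_, m, k), dst(_, m, k));
            } else if (Clear_OOB_K) {
                clear(dst(_, m, k));          // disabled here: never zero gmem
            }
        }
    } else if (Clear_OOB_MN) {
        clear(dst(_, m, _));                  // also disabled on this store path
    }
}
```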
@@ -688,7 +710,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Kernel_traits::Is_V_in_regs) {
         // Clear the smem tiles to account for predicated off loads
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+            gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
         flash::cp_async_fence();
     }
@@ -698,18 +720,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (!Is_first) {
         // Clear the smem tiles to account for predicated off loads
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
+            gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
     } else {
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
+            gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
+            gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
     }
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-        gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
+        gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
     );
     Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
@@ -726,23 +748,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     }
     // Tensor tKrK = make_fragment_like(tKsK);
-    // // copy(gmem_thr_copy_QKV, tKgK(_, _, _, 0), tKrK);
-    // copy(gmem_thr_copy_QKV, tKgK, tKrK);
+    // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
+    // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-        gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+        gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     if (!Kernel_traits::Is_V_in_regs) {
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+            gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
     }
     flash::cp_async_fence();
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
-        copy(tdOrdO, tdOsdO);
+        cute::copy(tdOrdO, tdOsdO);
         dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
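On the first pass, dO is also consumed here to form dP_sum before the main loop runs. In scalar terms, this is a per-row dot product of dO and O over the head dimension; a sketch only, ignoring the vectorized, warp-cooperative reduction the real `dot_do_o` performs across `kGmemThreadsPerRow` threads:

```cpp
// dP_sum[i] = sum over j of dO[i][j] * O[i][j], for each query row i in the block.
// Scalar sketch; dropout scaling is deferred, per the comment in compute_dot_do_o.
for (int i = 0; i < kBlockM; ++i) {
    float dp = 0.f;
    for (int j = 0; j < kHeadDim; ++j) {
        dp += float(dO[i][j]) * float(O[i][j]);  // accumulate in fp32
    }
    dP_sum[i] = dp;
}
```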
@@ -752,7 +774,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         __syncthreads();
         Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);
         CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view));  // M
-        copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view);
+        cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }
     auto seed = params.rng_state[0];
@@ -775,10 +797,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK);
         // #pragma unroll
         // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) {
-        //     copy(smem_thr_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
+        //     cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
         // }
         // if (cute::thread0()) { print(tSrK); }
-        flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV);
+        flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
+                    smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
@@ -827,7 +850,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
         Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
         Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP);  // ((Atom,AtomNum), MMA_N, MMA_N)
-        copy(smem_thr_copy_PdS, tPaP, tPsP);
+        cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
         // if (cute::thread0()) { print(tPaP); }
         // __syncthreads();
         // if (cute::thread0()) { print(sP); }
@@ -850,7 +873,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread0()) { print(dP_sum); }
         flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
-            acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV
+            acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
+            smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
         );
         // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
@@ -877,7 +901,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                                                  make_layout(get<0>(acc_dq.layout()),
                                                              get<2>(acc_dq.layout()),
                                                              get<1>(acc_dq.layout())));
-            copy(gmem_thr_copy_dQ_accum, tdQgdQaccum, acc_dq_reshaped);
+            cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped);
         }
         if (Double_buffer && m_block > m_block_min) {
@@ -887,7 +911,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             tSsQ.data() = tSsQ.data() + sQ_offset;
             // Advance gQ
             tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
             flash::cp_async_fence();
         }
@@ -896,7 +920,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Tensor tdSrdS = flash::convert_type<Element>(dS_reshaped);
         // if (cute::thread0()) { print(tPrP); }
         Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS);  // ((Atom,AtomNum), MMA_N, MMA_N)
-        copy(smem_thr_copy_PdS, tdSadS, tdSsdS);
+        cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
         __syncthreads();
         // Layout p_l = tPrP.layout();
@@ -904,7 +928,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
        // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
        // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
-        flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
+                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
         // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
         // if (cute::thread0()) { print(acc_dv); }
@@ -915,15 +940,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
            if (Is_first) {
                tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
-                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
-                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
+                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
+                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
            } else {
-                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
+                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
                flash::cp_async_fence();
            }
        }
-        flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, smem_thr_copy_dS, smem_thr_copy_Kt);
+        flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
+                    smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
         // if (cute::thread0()) { print(acc_dq); }
         if (m_block > m_block_min) {
@@ -945,7 +971,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                                                              get<2>(acc_dq.layout()),
                                                              get<1>(acc_dq.layout())));
            if (!Seq_parallel) {
-                copy(gmem_thr_copy_dQ_accum, acc_dq_reshaped, tdQgdQaccum);
+                cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum);
            } else {
                // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); }
                CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
@@ -958,10 +984,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            // Convert acc_dq from fp32 to fp16
            Tensor rdQ = flash::convert_type<Element>(acc_dq);
            Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
-            copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ);
+            cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
        }
-        flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
+        flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
+                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
         // if (cute::thread0()) { print(acc_dk); }
         if (Double_buffer) {  // Double buffer for sQ
             tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
@@ -970,12 +997,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            __syncthreads();
            // Advance gQ
            tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
            flash::cp_async_fence();
        }
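The `Double_buffer` branches above implement a two-slot smem pipeline for sQ: on each `m_block` iteration the read and write views toggle between the two tiles by ±`size(sQ)`, so the `cp.async` fetch of the next Q tile overlaps the MMAs consuming the current one. A condensed sketch of the pointer flip and prefetch, assuming the views set up earlier (not a drop-in excerpt):

```cpp
// Two Q tiles live back to back in smem; the parity of m_block picks the slot.
const int sQ_offset = (m_block % 2 == 0) ? size(sQ) : -size(sQ);
tQsQ.data() = tQsQ.data() + sQ_offset;       // gmem->smem store view: the free slot
tSsQ.data() = tSsQ.data() + sQ_offset;       // smem->reg load view for S = Q K^T
tdKsQt.data() = tdKsQt.data() + sQ_offset;   // transposed Q view used for dK
// Then prefetch the next (previous-index) Q tile into the freed slot:
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
flash::cp_async_fence();                     // fence so the consumer can wait on it
```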
if (Is_first && m_block > m_block_min) { if (Is_first && m_block > m_block_min) {
copy(tdOrdO, tdOsdO); cute::copy(tdOrdO, tdOsdO);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum, dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
} }
...@@ -983,14 +1010,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -983,14 +1010,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (Is_last) { if (Is_last) {
__syncthreads(); __syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ)); Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ); cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
#pragma unroll #pragma unroll
for (int m = 0; m < size<1>(tdQgdQ); ++m) { for (int m = 0; m < size<1>(tdQgdQ); ++m) {
if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
copy(gmem_thr_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _)); cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
} }
} }
} }
...@@ -1014,7 +1041,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -1014,7 +1041,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
// Partition sdV and sdK to match the accumulator partitioning // Partition sdV and sdK to match the accumulator partitioning
auto smem_thr_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv).get_thread_slice(tidx); auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
...@@ -1026,8 +1054,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -1026,8 +1054,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// If Is_last, there's already a __syncthreads() at the end of the loop. // If Is_last, there's already a __syncthreads() at the end of the loop.
if (!Is_last) { __syncthreads(); } if (!Is_last) { __syncthreads(); }
copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK); cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV); cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
...@@ -1040,7 +1068,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -1040,7 +1068,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{})); make_stride(params.dv_row_stride, _1{}));
auto gmem_thr_copy_dKV = typename Kernel_traits::GmemTiledCopydKV{}.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
...@@ -1048,9 +1077,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -1048,9 +1077,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
__syncthreads(); __syncthreads();
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK)); Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
copy(gmem_thr_copy_dKV, tdKsdK, tdKrdK); cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV)); Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
copy(gmem_thr_copy_dKV, tdVsdV, tdVrdV); cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK))); Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
...@@ -1058,10 +1087,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -1058,10 +1087,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem // Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
); );
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
); );
} }
...@@ -1163,9 +1192,12 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1163,9 +1192,12 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(sdS.data().get())), Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(sdS.data().get())),
Shape<Int<kBlockM>>{}); Shape<Int<kBlockM>>{});
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_dO = typename Kernel_traits::GmemTiledCopydO{}.get_thread_slice(tidx); auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
auto gmem_thr_copy_dKV_accum = typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd{}.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
...@@ -1176,8 +1208,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1176,8 +1208,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
Tensor tdKgdKaccum = gmem_thr_copy_dKV_accum.partition_D(gdKaccum); Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKV_accum.partition_D(gdVaccum); Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
typename Kernel_traits::TiledMmaSdP tiled_mma_sdp; typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
...@@ -1204,32 +1236,39 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1204,32 +1236,39 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// Copy Atom retiling // Copy Atom retiling
// //
auto smem_thr_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
auto smem_thr_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_KV.partition_S(sK); Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);
// Partition sP and sdS to match the accumulator partitioning // Partition sP and sdS to match the accumulator partitioning
// This has to be tiled_mma_sdp, not tiled_mma_dkv // This has to be tiled_mma_sdp, not tiled_mma_dkv
auto smem_thr_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx); auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
auto smem_thr_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx); auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
auto smem_thr_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv).get_thread_slice(tidx); auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);
auto smem_thr_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq).get_thread_slice(tidx); auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);
auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);
Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);
auto smem_thr_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq).get_thread_slice(tidx); auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);
auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
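These `make_tiled_copy_A/B/C` helpers tile a shared-memory copy atom to match the A-, B-, or C-operand layout of a tiled MMA; the `_warpcontiguousN` variants additionally rearrange warps along N. The thread slice then partitions smem on the source side and retiles the MMA register fragment on the destination side. A condensed sketch of how one operand flows, assuming `tiled_mma`, `sA`, and `tCrA` shaped as in the surrounding code:

```cpp
auto smem_tiled_copy_A = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_A   = smem_tiled_copy_A.get_thread_slice(tidx);

Tensor tCsA           = smem_thr_copy_A.partition_S(sA);  // smem, in copy-atom tiles
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);   // MMA register fragment, re-viewed
cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));  // load k-slice 0
```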
// //
...@@ -1263,15 +1302,15 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1263,15 +1302,15 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// TODO: Might need to exit early and write 0 to gdQ. // TODO: Might need to exit early and write 0 to gdQ.
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>( flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
); );
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>( flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
); );
Tensor tQrQ = make_fragment_like(tQgQ); Tensor tQrQ = make_fragment_like(tQgQ);
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>( flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
); );
int n_block = n_block_max - 1; int n_block = n_block_max - 1;
...@@ -1282,10 +1321,10 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1282,10 +1321,10 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
} }
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>( flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
); );
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>( flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
); );
Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
...@@ -1304,7 +1343,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1304,7 +1343,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
cute::cp_async_fence(); cute::cp_async_fence();
Tensor dP_sum = make_fragment_like(lse); Tensor dP_sum = make_fragment_like(lse);
copy(tdOrdO, tdOsdO); cute::copy(tdOrdO, tdOsdO);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>( dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
tdOrdO, tdOrO, sdPsum, sdPsum, tdOrdO, tdOrO, sdPsum, sdPsum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout
...@@ -1324,7 +1363,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1324,7 +1363,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
flash::cp_async_wait<0>(); flash::cp_async_wait<0>();
__syncthreads(); __syncthreads();
flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV); flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
...@@ -1359,7 +1399,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1359,7 +1399,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8. // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout())); Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N)
copy(smem_thr_copy_PdS, tPaP, tPsP); cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N) Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA
...@@ -1367,7 +1407,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1367,7 +1407,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA
clear(acc_dp); clear(acc_dp);
flash::gemm(acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, smem_thr_copy_QdO, smem_thr_copy_KV); flash::gemm(acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor dS = make_tensor(acc_dp.data(), scores.layout()); Tensor dS = make_tensor(acc_dp.data(), scores.layout());
...@@ -1386,7 +1427,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1386,7 +1427,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// Convert dS from fp32 to fp16 // Convert dS from fp32 to fp16
Tensor tdSrdS = flash::convert_type<Element>(dS_reshaped); Tensor tdSrdS = flash::convert_type<Element>(dS_reshaped);
Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N)
copy(smem_thr_copy_PdS, tdSadS, tdSsdS); cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
__syncthreads(); __syncthreads();
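The round trip through shared memory here is deliberate: P and dS are produced in the SdP accumulator layout, but the dV and dK gemms consume them transposed (`sPt`/`sdSt`), and the transpose is realized simply by reading the same smem bytes through a transposed layout. A sketch of the two halves, using the tensors defined above:

```cpp
// Write side: accumulator-shaped tiles of dS go to shared memory...
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
__syncthreads();  // the whole tile must be resident before anyone reads it transposed
// ...read side: the same bytes viewed through the transposed smem layout (sdSt)
// are what flash::gemm streams into registers for the dK matmul.
Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
```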
if (n_block > 0) { if (n_block > 0) {
...@@ -1397,8 +1438,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1397,8 +1438,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// Advance gK, gV // Advance gK, gV
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions. // isn't right and we get race conditions.
cute::cp_async_fence(); cute::cp_async_fence();
...@@ -1406,7 +1447,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1406,7 +1447,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
clear(acc_dv); clear(acc_dv);
flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt); flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(acc_dv); } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(acc_dv); }
tdVgdVaccum.data() = tdVgdVaccum.data() + (-int(kBlockN * params.d_rounded)); tdVgdVaccum.data() = tdVgdVaccum.data() + (-int(kBlockN * params.d_rounded));
#pragma unroll #pragma unroll
...@@ -1415,12 +1457,14 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1415,12 +1457,14 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
__syncthreads(); __syncthreads();
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
clear(acc_dk); clear(acc_dk);
flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, smem_thr_copy_PdSt, smem_thr_copy_QdOt); flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
tdKgdKaccum.data() = tdKgdKaccum.data() + (-int(kBlockN * params.d_rounded)); tdKgdKaccum.data() = tdKgdKaccum.data() + (-int(kBlockN * params.d_rounded));
#pragma unroll #pragma unroll
for (int i = 0; i < size(acc_dk); ++i) { atomicAdd(&tdKgdKaccum(i), acc_dk(i)); } for (int i = 0; i < size(acc_dk); ++i) { atomicAdd(&tdKgdKaccum(i), acc_dk(i)); }
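dK and dV are accumulated into fp32 global buffers with `atomicAdd` because every query row-block, each running in its own thread block, contributes partial sums to the same key/value rows; the accumulator view is walked backward by one n-block stride per iteration since `n_block` counts down. A sketch of one accumulation step, assuming `acc_dk` holds this block's fp32 partials:

```cpp
// Step the gmem accumulator view to this n_block (the loop walks n_block downward)...
tdKgdKaccum.data() = tdKgdKaccum.data() + (-int(kBlockN * params.d_rounded));
#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) {
    // fp32 atomic add: concurrent m-block CTAs may hit the same dK element safely.
    atomicAdd(&tdKgdKaccum(i), acc_dk(i));
}
```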
flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, smem_thr_copy_dS, smem_thr_copy_Kt); flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
// Double buffer for sK // Double buffer for sK
tdQsKt.data() = tdQsKt.data() + (n_block % 2 == 0 ? size(sK) : -size(sK)); tdQsKt.data() = tdQsKt.data() + (n_block % 2 == 0 ? size(sK) : -size(sK));
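sK is double-buffered: while the dQ gemm reads the current K tile, the `cp.async` loads issued above fill the other buffer. The two buffers sit back to back in shared memory, so switching is a single add or subtract of one buffer size keyed on `n_block` parity. A minimal ping-pong sketch with illustrative names (`read_view`, `tile_elems` are hypothetical):

```cpp
// Two tiles at smem_base and smem_base + tile_elems; each iteration the "read"
// view flips to the buffer the in-flight async loads just finished filling.
int offset = (n_block % 2 == 0) ? tile_elems : -tile_elems;
read_view.data() = read_view.data() + offset;
```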
...@@ -1436,12 +1480,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1436,12 +1480,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor sdQ = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutdQ{}); Tensor sdQ = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutdQ{});
// Partition sdV and sdK to match the accumulator partitioning // Partition sdV and sdK to match the accumulator partitioning
auto smem_thr_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq).get_thread_slice(tidx); auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
__syncthreads(); __syncthreads();
copy(smem_thr_copy_dQ, taccdQrdQ, taccdQsdQ); cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
...@@ -1449,14 +1494,15 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1449,14 +1494,15 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Shape<Int<kBlockM>, Int<kHeadDim>>{}, Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{})); make_stride(params.dq_row_stride, _1{}));
auto gmem_thr_copy_dQ = typename Kernel_traits::GmemTiledCopydQ{}.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
__syncthreads(); __syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ)); Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
copy(gmem_thr_copy_dQ, tdQsdQ, tdQrdQ); cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
...@@ -1467,7 +1513,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1467,7 +1513,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
} }
// Clear_OOB_K must be false since we don't want to write zeros to gmem // Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
); );
} }
...@@ -77,7 +77,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T ...@@ -77,7 +77,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
flash::reduce_sum(scores, scores_sum); flash::reduce_sum(scores, scores_sum);
} else { } else {
Tensor scores_max_prev = make_fragment_like(scores_max); Tensor scores_max_prev = make_fragment_like(scores_max);
copy(scores_max, scores_max_prev); cute::copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max); flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
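This branch implements the online-softmax update: the previous row maxima are saved, the maxima are refreshed over the new block (`zero_init=false` folds the old values in), and the already-accumulated output rows are rescaled so everything stays referenced to the new maximum. A sketch of the rescale, with the softmax scaling factors omitted for brevity:

```cpp
// For each row: m_new >= m_prev, so exp(m_prev - m_new) <= 1 shrinks old contributions.
#pragma unroll
for (int mi = 0; mi < size(scores_max); ++mi) {
    const float scale = expf(scores_max_prev(mi) - scores_max(mi));
    scores_sum(mi) *= scale;   // running softmax denominator
    #pragma unroll
    for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
```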
...@@ -103,7 +103,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T ...@@ -103,7 +103,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy> template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem( inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) { ) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout(); Layout l = tOrP.layout();
...@@ -112,7 +112,7 @@ inline __device__ void write_softmax_to_gmem( ...@@ -112,7 +112,7 @@ inline __device__ void write_softmax_to_gmem(
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) { for (int mi = 0; mi < size<1>(tPrP); ++mi) {
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
} }
}; };
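The hunk elides the line that flattens `tOrP` from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N); with CuTe layouts, that kind of mode merge can be expressed by regrouping layout modes rather than moving data. A plausible sketch (the actual elided line may differ):

```cpp
// Regroup modes 1 and 2 into one hierarchical mode; no data movement, just a new view.
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(),
                          make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
```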
...@@ -186,8 +186,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -186,8 +186,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
...@@ -209,16 +211,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -209,16 +211,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Copy Atom retiling // Copy Atom retiling
// //
auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();} // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70 // TODO: this might need to change if we change the mma instruction in SM70
...@@ -269,7 +274,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -269,7 +274,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tQrQ = make_fragment_like(tQgQ); Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM); binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
...@@ -286,13 +291,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -286,13 +291,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads(); __syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads(); __syncthreads();
} }
int n_block = n_block_max - 1; int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN); binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence(); cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
...@@ -303,7 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -303,7 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads(); __syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
} }
auto seeds = at::cuda::philox::unpack(params.philox_args); auto seeds = at::cuda::philox::unpack(params.philox_args);
...@@ -335,17 +340,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -335,17 +340,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV // Advance gV
if (masking_step > 0) { if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else { } else {
// Clear the smem tiles to account for predicated off loads // Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>( flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
); );
} }
cute::cp_async_fence(); cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>( flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
); );
// if (cute::thread0()) { print(acc_s); } // if (cute::thread0()) { print(acc_s); }
...@@ -382,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -382,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if (n_block > 0) { if (n_block > 0) {
// Advance gK // Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions. // isn't right and we get race conditions.
cute::cp_async_fence(); cute::cp_async_fence();
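The fence/wait pairing is the cp.async handshake: each batch of issued loads is committed as one group, the consumer waits until the groups it depends on have landed, and a barrier then publishes the smem tile to all warps. Keeping the fence inside the `if (n_block > 0)` block keeps the number of committed groups in lockstep with the waits, which is what the race-condition comment above is about. The skeleton:

```cpp
cute::cp_async_fence();      // commit the loads issued so far as one async group
/* ... independent work: the gemm on the tile already resident in smem ... */
flash::cp_async_wait<0>();   // block until every committed group has completed
__syncthreads();             // now all warps may read the freshly loaded tile
```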
...@@ -402,12 +408,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -402,12 +408,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
uint32_t block_col_idx = n_block * (kBlockN / 32); uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) { if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP); Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy); cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>( flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps block_row_idx, block_col_idx, kNWarps
); );
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN); tPgP.data() = tPgP.data() + (-kBlockN);
} }
if (Is_dropout) { if (Is_dropout) {
...@@ -416,7 +422,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -416,7 +422,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
} }
// if (cute::thread0()) { print(tOrP); } // if (cute::thread0()) { print(tOrP); }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); } // if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration // This check is at the end of the loop since we always have at least 1 iteration
...@@ -434,11 +440,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -434,11 +440,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads(); __syncthreads();
// Advance gV // Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence(); cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>( flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
); );
flash::cp_async_wait<0>(); flash::cp_async_wait<0>();
...@@ -446,7 +453,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -446,7 +453,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
if (n_block > 0) { if (n_block > 0) {
// Advance gK // Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions. // isn't right and we get race conditions.
cute::cp_async_fence(); cute::cp_async_fence();
...@@ -464,12 +471,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -464,12 +471,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
uint32_t block_col_idx = n_block * (kBlockN / 32); uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) { if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP); Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy); cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>( flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps block_row_idx, block_col_idx, kNWarps
); );
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN); tPgP.data() = tPgP.data() + (-kBlockN);
} }
if (Is_dropout) { if (Is_dropout) {
...@@ -477,7 +484,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -477,7 +484,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
block_row_idx, block_col_idx, kNWarps); block_row_idx, block_col_idx, kNWarps);
} }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
} }
// Epilogue // Epilogue
...@@ -501,7 +508,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -501,7 +508,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor rO = flash::convert_type<Element>(acc_o); Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning // Partition sO to match the accumulator partitioning
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
...@@ -509,7 +517,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -509,7 +517,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// sO has the same size as sQ, so we don't need to sync here. // sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
copy(smem_thr_copy_O, taccOrO, taccOsO); cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
...@@ -520,14 +528,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -520,14 +528,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse), Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{}); Shape<Int<kBlockM>>{}, Stride<_1>{});
auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
__syncthreads(); __syncthreads();
Tensor tOrO = make_tensor<Element>(shape(tOgO)); Tensor tOrO = make_tensor<Element>(shape(tOgO));
copy(gmem_thr_copy_O, tOsO, tOrO); cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
...@@ -554,7 +563,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -554,7 +563,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
} }
// Clear_OOB_K must be false since we don't want to write zeros to gmem // Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
); );
} }
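The epilogue above stages O through shared memory on purpose: the MMA accumulator layout is scattered per thread, so writing it straight to gmem would produce narrow, uncoalesced stores. Bouncing through smem re-lays the data so the gmem tiled copy can issue wide vectorized stores. Condensed, using the names from the code:

```cpp
Tensor rO = flash::convert_type<Element>(acc_o);   // fp32 accumulators -> fp16/bf16
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);   // regs -> smem in MMA-shaped tiles
__syncthreads();                                   // whole tile resident
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);         // smem -> regs in gmem-friendly vectors
// Predicated final store; Clear_OOB_K=false so out-of-bounds columns are never written.
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
    gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
```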
...@@ -173,10 +173,12 @@ static __device__ inline T run(T x, Operator &op) { ...@@ -173,10 +173,12 @@ static __device__ inline T run(T x, Operator &op) {
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1, template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4, typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy0, typename TiledCopy1> typename TiledMma, typename TiledCopyA, typename TiledCopyB,
typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma, Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
...@@ -184,13 +186,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 ...@@ -184,13 +186,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll #pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) { for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) { if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
} }
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
} }
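The signature change is mechanical but load-bearing: the `TiledCopy` pair is what `cute::copy` dispatches on, while the `ThrCopy` pair supplies `retile_D` for the register views; inside the loop, the copy of k-slice `i + 1` is issued before the MMA on slice `i`, overlapping smem traffic with math. A call site then looks like this (mirroring the refactored kernels earlier in this commit):

```cpp
flash::gemm</*A_in_regs=*/false>(
    acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma,
    smem_tiled_copy_Q, smem_tiled_copy_K,   // carry the copy atoms for cute::copy
    smem_thr_copy_Q,  smem_thr_copy_K);     // per-thread views for retile_D
```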
...@@ -199,19 +201,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 ...@@ -199,19 +201,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy> typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll #pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) { for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) { if (i < size<2>(tCrA) - 1) {
copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
} }
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
} }
...@@ -319,7 +322,7 @@ void cp_async_wait() { ...@@ -319,7 +322,7 @@ void cp_async_wait() {
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true, template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3> typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S, inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) { Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
...@@ -335,13 +338,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const & ...@@ -335,13 +338,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
#pragma unroll #pragma unroll
for (int k = 0; k < size<2>(S); ++k) { for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) { if (Is_even_K || predicate_K(k)) {
copy(thr_copy, S(_, m, k), D(_, m, k)); cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) { } else if (Clear_OOB_K) {
clear(D(_, m, k)); cute::clear(D(_, m, k));
} }
} }
} else if (Clear_OOB_MN) { } else if (Clear_OOB_MN) {
clear(D(_, m, _)); cute::clear(D(_, m, _));
} }
} }
// TD [2023-04-13]: Strange that the code below can cause race condition. // TD [2023-04-13]: Strange that the code below can cause race condition.
...@@ -350,7 +353,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const & ...@@ -350,7 +353,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
// #pragma unroll // #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) { // for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(thr_copy, S(_, m, _), D(_, m, _)); // copy(tiled_copy, S(_, m, _), D(_, m, _));
// } else if (Clear_OOB_MN) { // } else if (Clear_OOB_MN) {
// clear(D(_, m, _)); // clear(D(_, m, _));
// } // }
...@@ -362,7 +365,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const & ...@@ -362,7 +365,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
// #pragma unroll // #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) { // for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(thr_copy, S(_, m, k), D(_, m, k)); // copy(tiled_copy, S(_, m, k), D(_, m, k));
// } else if (Clear_OOB_MN) { // } else if (Clear_OOB_MN) {
// clear(D(_, m, k)); // clear(D(_, m, k));
// } // }
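Callers drive this routine with an identity tensor, which maps every copied element back to its (row, column) coordinate within the block, plus a per-k-slice predicate vector; the row bound arrives as `max_MN`. A sketch of the setup for a Q load, assuming the names used in the kernels above:

```cpp
// Identity tensor: element (i, j) "contains" the coordinate (i, j) itself.
Tensor cQ   = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);              // partitioned like the data
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQcQ)));   // one predicate per k-slice
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }

flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                             binfo.actual_seqlen_q - m_block * kBlockM);
```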
...@@ -783,13 +783,13 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_ ...@@ -783,13 +783,13 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
# @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
@pytest.mark.parametrize('d', [128]) # @pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize('seqlen', [128])