Commit 2416ddf7 authored by ltqin's avatar ltqin
Browse files

first

parent 38f48480
......@@ -15,6 +15,7 @@ add_example_executable(example_grouped_multihead_attention_forward_v2 grouped_mu
add_example_executable(example_batched_multihead_attention_forward_v2 batched_multihead_attention_forward_v2.cpp)
add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
add_example_executable(example_batched_multihead_attention_backward_v2_phased batched_multihead_attention_backward_v2_phased.cpp)
add_example_executable(example_grouped_multihead_attention_train_v2 grouped_multihead_attention_train_v2.cpp)
add_example_executable(example_batched_multihead_attention_train_v2 batched_multihead_attention_train_v2.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
Computation graph:
K^T V
| |
| |
Q --- * ----- Softmax ----- * --> Y
S P
Kernel inputs:
Q, K, V, Y, dY, per-row softmax stats (LSE)
Kernel outputs:
dQ, dK, dV
*/
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using U16 = unsigned short;
using INT32 = int32_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using InputDataType = F16;
using OutputDataType = F16;
using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
#else
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = false;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1.
// If 32 < DIM <= 64 , ues prototype1.
// If 64 < DIM <= 128, ues prototype2.
#if(DIM <= 32)
// clang-format off
using DeviceGemmInstance =
// ##################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 64)
// clang-format off
using DeviceGemmInstance =
// ##################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 64, 32, 8, 8, 2, 32, 32, 2, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ##################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, false, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#elif(DIM <= 128)
// clang-format off
using DeviceGemmInstance =
// ##################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ##################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ##################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ##################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 4, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 64, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 64, 128, 128, 128, 32, 8, 8, 2, 32, 32, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 128, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 32, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 1, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
AccDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, InputDataType, AccDataType>;
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
InputDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
// Ref Gemm for backward pass
// fp16 in, fp16 out
using ReferenceGemm0GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
InputDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
OutputDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref dropout
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorZ,
typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k,
const TensorV& v_g_n_o,
const float alpha,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
ref_softmax_invoker.Run(ref_softmax_argument);
// P_dropped
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
}
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true;
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1 = 6;
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.0;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
p_drop = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
std::cout << "p_drop: " << p_drop << std::endl;
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
// y dot ygrad sum
std::vector<ck::index_t> d_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> d_gs_ms_strides{
G1 * M, M, 1}; // O dot Ograd row sum layout [G0, G1, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<AccDataType> d_gs_ms(d_gs_ms_lengths, d_gs_ms_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "d_gs_ms_os: " << d_gs_ms.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method)
{
case 0: break;
case 1:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
break;
case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
break;
case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break;
default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(AccDataType) * d_gs_ms.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// get z matrix
{
auto argument = gemm.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<AccDataType*>(d_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
d_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
invoker.Run(argument, StreamConfig{nullptr, false});
}
// not need output z matrix
auto argument = gemm.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<AccDataType*>(d_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
d_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
qgrad_device_buf.SetZero();
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// 5 GEMM ops in total:
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// 3x MNK + 2x MNO
std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
std::size_t num_btype =
(sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O) *
BatchCount +
sizeof(LSEDataType) * M * BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
// copy z matirx data form device
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true;
if(do_verification)
{
// run fwd again for y, cause z_g_m_n update
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
// call kernel again
qgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<InputDataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<InputDataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<InputDataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
#if PRINT_HOST
{
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "v_g_n_o ref:\n" << v_g_n_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
}
#endif
// Gradients
auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
using RefGemm0GradArg = ReferenceGemm0GradInstance::Argument;
auto ref_gemm1_grad = ReferenceGemm1GradInstance{};
auto ref_gemm1_grad_invoker = ref_gemm1_grad.MakeInvoker();
using RefGemm1GradArg = ReferenceGemm1GradInstance::Argument;
// dP_dropout = dY * V^T
auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
#if PRINT_HOST
{
std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
}
#endif
// dP = dP_dropout x Z
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
}
self(idx_gmn) = ck::type_convert<InputDataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
});
#if PRINT_HOST
{
std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
std::cout << "p_g_m_n ref:\n" << p_g_m_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
std::cout << "y_g_m_o ref:\n" << y_g_m_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
}
#endif
// dV = P_drop^T * dY
auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST
{
std::cout << "===== dV = P^T * dY\n";
std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
}
#endif
// dQ = alpha * dS * K
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
{
std::cout << "===== dQ = alpha * dS * K\n";
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
}
#endif
// dK = alpha * dS^T * Q
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
{
std::cout << "===== dK = alpha * dS^T * Q\n";
std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
}
#endif
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData,
"error",
1e-2,
1e-2);
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
}
int main(int argc, char* argv[]) { return run(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename InputDataType,
typename DDataType,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename DGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_ydotygrad_v1(
const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid,
DDataType* __restrict__ p_d_grid,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_M d_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
// GridwiseGemm::test();
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_d_grid + d_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m,
block_2_ctile_map);
#else
ignore = p_y_grid;
ignore = p_ygrad_grid;
ingore = p_d_grid;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_m;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename DDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
typename DGridDescriptor_M,
typename YGradGridDesc_O0_M_O1,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_b1_grid,
const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid,
const DDataType* __restrict__ p_d_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m,
const DGridDescriptor_M d_grid_desc_m,
const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const float p_drop,
const unsigned long long seed,
const unsigned long long offset,
const index_t raw_m_padded,
const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
ck::philox ph(seed, 0, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_d_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
d_grid_desc_m,
ygrad_grid_desc_o0_m_o1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph,
z_random_matrix_offset,
raw_n_padded,
i);
}
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_d_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
d_grid_desc_m,
ygrad_grid_desc_o0_m_o1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph,
z_random_matrix_offset,
raw_n_padded,
0);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_drop;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock, // Gemm0NPerBlock
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DDataType = GemmAccDataType;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
Number<NPerBlock>,
Number<KPerBlock>,
Number<Gemm1NPerBlock>>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
ASpec,
BSpec,
B1Spec,
CSpec>;
/*
Descriptors for inputs:
Q, K, V, Y, dY, per-row softmax stats
Descriptors for outputs:
dQ, dK, dV
*/
// Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{});
}
// K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
Number<BK1>{});
}
// V in Gemm B1 position
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
}
//
// dV = P^T * dY
//
// VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
const std::vector<index_t>& v_gs_os_ns_strides_vec)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
}
const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
}
template <typename YGridDesc_M_O>
static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o)
{
const auto M = ygrad_grid_desc_m_o.GetLength(I0);
const auto O = ygrad_grid_desc_m_o.GetLength(I1);
const auto Y_M0 = M / Y_M1;
return transform_tensor_descriptor(
ygrad_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
make_pass_through_transform(O)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec,
const std::vector<index_t>& y_gs_ms_os_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec),
Number<Y_O1>{});
}
// V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
const std::vector<index_t>& v_gs_os_ns_strides_vec)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
}
const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
// N_O to O0_N_O1; to refactor
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
const std::vector<index_t>& z_gs_ms_ns_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec,
const std::vector<index_t>& q_gs_ms_ks_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec,
const std::vector<index_t>& k_gs_ns_ks_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
{
const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(lse_grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return lse_grid_desc_mraw;
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
return MaskOutUpperTrianglePredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
}
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
LSEDataType,
DDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
KGridDesc_N_K,
ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
DGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
// GridwiseYDotYGrad
using GridwiseYDotYGrad =
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
// datatype
DDataType,
YGridDesc_M_O,
DGridDesc_M,
BlockSize,
BlockSize,
32>;
// Argument
struct Argument : public BaseArgument
{
Argument(
const InputDataType* p_a_grid,
const InputDataType* p_b_grid,
ZDataType* p_z_grid,
const InputDataType* p_b1_grid,
const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid,
DDataType* p_d_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
p_lse_grid_{p_lse_grid},
p_d_grid_{p_d_grid},
p_ygrad_grid_{p_ygrad_grid},
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
d_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
ygrad_grid_desc_o0_m_o1_{DeviceOp::MakeYGradGridDescriptor_O0_M_O1(
c_gs_ms_gemm1ns_lengths, c_gs_ms_gemm1ns_strides)},
// batch offsets
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
y_grid_desc_m_o_))
{
y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o_);
}
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
d_y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o_);
// Print();
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
}
void Print() const
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
// a_grid_desc_g_m_k_.Print();
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print();
std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// c_grid_desc_g_m_n_.Print();
std::cout << "k_grid_desc_n_k_: " << k_grid_desc_n_k_.GetLength(I0) << ", "
<< k_grid_desc_n_k_.GetLength(I1) << '\n';
std::cout << "ygrad_grid_desc_o0_m_o1_: " << ygrad_grid_desc_o0_m_o1_.GetLength(I0)
<< ", " << ygrad_grid_desc_o0_m_o1_.GetLength(I1) << ", "
<< ygrad_grid_desc_o0_m_o1_.GetLength(I2) << '\n';
}
// pointers
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
DDataType* p_d_grid_;
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_;
DGridDesc_M d_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_;
YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1_;
// batch offsets
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
d_y_grid_desc_mblock_mperblock_oblock_operblock_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_;
unsigned long long seed_;
unsigned long long offset_;
index_t m_raw_padded_;
index_t n_raw_padded_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! unsupported argument");
}
float ave_time = 0;
{
const index_t grid_size =
(Deterministic
? 1
: arg.d_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
arg.batch_count_;
std::cout << "grid_size: " << grid_size
<< "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_
<< " arg.batch_count_: " << arg.batch_count_ << std::endl;
std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock
<< std::endl;
std::cout << "arg.y_grid_desc_m_o_: {" << arg.y_grid_desc_m_o_.GetLength(I0) << ","
<< arg.y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
auto launch_kernel = [&]() {
const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v1<
GridwiseYDotYGrad,
InputDataType,
DDataType,
typename GridwiseYDotYGrad::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::DGridDesc_M,
typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_c_grid_,
arg.p_ygrad_grid_,
arg.p_d_grid_,
arg.d_y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.d_grid_desc_m_,
arg.d_block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_);
};
ave_time = launch_kernel();
}
const index_t grid_size =
(Deterministic ? 1
: arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_)) *
arg.batch_count_;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
DDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M,
DeviceOp::YGradGridDesc_O0_M_O1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_z_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_lse_grid_,
arg.p_d_grid_,
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
arg.d_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_drop_,
arg.seed_,
arg.offset_,
arg.m_raw_padded_,
arg.n_raw_padded_);
};
ave_time += launch_kernel(integral_constant<bool, false>{});
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
// TODO: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto c_extent_lowest = Gemm1NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b_stride_lowest =
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
const auto c_stride_lowest =
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || c_stride_lowest == 1))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_m_o_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const InputDataType* p_a,
const InputDataType* p_b,
ZDataType* p_z,
const InputDataType* p_b1,
const InputDataType* p_c,
const LSEDataType* p_lse,
DDataType* p_d_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_a,
p_b,
p_z,
p_b1,
p_c,
p_lse,
p_d_grid,
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_vgrad_grid,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
d_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_drop,
seeds};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_z,
const void* p_b1,
const void* p_c,
const void* p_lse,
void* p_d_grid,
const void* p_ygrad_grid,
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z),
static_cast<const InputDataType*>(p_b1),
static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse),
static_cast<DDataType*>(p_d_grid),
static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
d_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_drop,
seeds);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename InputDataType,
typename DDataType,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename DGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_ydotygrad_v2(
const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid,
DDataType* __restrict__ p_d_grid,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_M d_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
// GridwiseGemm::test();
GridwiseGemm::Run(p_y_grid + c_batch_offset,
p_ygrad_grid + c_batch_offset,
p_d_grid + d_batch_offset,
p_shared,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_m,
block_2_ctile_map);
#else
ignore = p_y_grid;
ignore = p_ygrad_grid;
ingore = p_d_grid;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_m;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename LSEDataType,
typename DDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
typename DGridDescriptor_M,
typename YGradGridDesc_M0_O_M1,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_b1_grid,
const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid,
const DDataType* __restrict__ p_d_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m,
const DGridDescriptor_M d_grid_desc_m,
const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const float p_drop,
const unsigned long long seed,
const unsigned long long offset,
const index_t raw_m_padded,
const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
ck::philox ph(seed, 0, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_d_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
d_grid_desc_m,
ygrad_grid_desc_m0_o_m1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph,
z_random_matrix_offset,
raw_n_padded,
i);
}
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_d_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
d_grid_desc_m,
ygrad_grid_desc_m0_o_m1,
block_2_ctile_map,
c0_matrix_mask,
p_drop,
ph,
z_random_matrix_offset,
raw_n_padded,
0);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_drop;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename InputDataType,
typename OutputDataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock, // Gemm0NPerBlock
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DDataType = GemmAccDataType;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
Number<NPerBlock>,
Number<KPerBlock>,
Number<Gemm1NPerBlock>>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
ASpec,
BSpec,
B1Spec,
CSpec>;
/*
Descriptors for inputs:
Q, K, V, Y, dY, per-row softmax stats
Descriptors for outputs:
dQ, dK, dV
*/
// Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{});
}
// K in Gemm B0 position
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
Number<BK1>{});
}
// V in Gemm B1 position
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
}
//
// dV = P^T * dY
//
// VGrad in Gemm C position
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
const std::vector<index_t>& v_gs_os_ns_strides_vec)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
}
const auto vgrad_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second;
return PadTensorDescriptor(vgrad_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
}
template <typename YGridDesc_M_O>
static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o)
{
const auto M = ygrad_grid_desc_m_o.GetLength(I0);
const auto O = ygrad_grid_desc_m_o.GetLength(I1);
const auto Y_M0 = M / Y_M1;
return transform_tensor_descriptor(
ygrad_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
make_pass_through_transform(O)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec,
const std::vector<index_t>& y_gs_ms_os_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec),
Number<Y_O1>{});
}
// V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
const std::vector<index_t>& v_gs_os_ns_strides_vec)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
}
const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
// N_O to O0_N_O1; to refactor
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
const std::vector<index_t>& z_gs_ms_ns_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec,
const std::vector<index_t>& q_gs_ms_ks_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec,
const std::vector<index_t>& k_gs_ns_ks_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
{
const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(lse_grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return lse_grid_desc_mraw;
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using DGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
return MaskOutUpperTrianglePredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
}
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
ZDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
LSEDataType,
DDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
KGridDesc_N_K,
ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
DGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
true,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
// GridwiseYDotYGrad
using GridwiseYDotYGrad =
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType, // TODO: distinguish A/B
// datatype
DDataType,
YGridDesc_M_O,
DGridDesc_M,
BlockSize,
BlockSize,
64>;
// Argument
struct Argument : public BaseArgument
{
Argument(
const InputDataType* p_a_grid,
const InputDataType* p_b_grid,
ZDataType* p_z_grid,
const InputDataType* p_b1_grid,
const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid,
DDataType* p_d_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
p_lse_grid_{p_lse_grid},
p_d_grid_{p_d_grid},
p_ygrad_grid_{p_ygrad_grid},
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
d_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(d_gs_ms_lengths[NumDimG])},
k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
// batch offsets
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
y_grid_desc_m_o_))
{
y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o_);
}
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
d_y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseYDotYGrad::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o_);
// Print();
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
}
void Print() const
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
// a_grid_desc_g_m_k_.Print();
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print();
std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// c_grid_desc_g_m_n_.Print();
std::cout << "k_grid_desc_n_k_: " << k_grid_desc_n_k_.GetLength(I0) << ", "
<< k_grid_desc_n_k_.GetLength(I1) << '\n';
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
<< ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
}
// pointers
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
DDataType* p_d_grid_;
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_;
DGridDesc_M d_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_;
YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_;
// batch offsets
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
typename GridwiseYDotYGrad::DefaultBlock2CTileMap d_block_2_ctile_map_;
typename GridwiseYDotYGrad::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
d_y_grid_desc_mblock_mperblock_oblock_operblock_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_;
unsigned long long seed_;
unsigned long long offset_;
index_t m_raw_padded_;
index_t n_raw_padded_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! unsupported argument");
}
float ave_time = 0;
{
const index_t grid_size =
(Deterministic
? 1
: arg.d_block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
arg.batch_count_;
std::cout << "grid_size: " << grid_size
<< "grid_size / arg.batch_count_: " << grid_size / arg.batch_count_
<< " arg.batch_count_: " << arg.batch_count_ << std::endl;
std::cout << "MPerBlock: " << MPerBlock << " Gemm1NPerBlock: " << Gemm1NPerBlock
<< std::endl;
std::cout << "arg.y_grid_desc_m_o_: {" << arg.y_grid_desc_m_o_.GetLength(I0) << ","
<< arg.y_grid_desc_m_o_.GetLength(I1) << "}" << std::endl;
auto launch_kernel = [&]() {
const auto kernel = kernel_batched_multihead_attention_backward_ydotygrad_v2<
GridwiseYDotYGrad,
InputDataType,
DDataType,
typename GridwiseYDotYGrad::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::DGridDesc_M,
typename GridwiseYDotYGrad::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_c_grid_,
arg.p_ygrad_grid_,
arg.p_d_grid_,
arg.d_y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.d_grid_desc_m_,
arg.d_block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_);
};
ave_time = launch_kernel();
}
const index_t grid_size =
(Deterministic ? 1
: arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_)) *
arg.batch_count_;
// Gemm0_K
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
DDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M,
DeviceOp::YGradGridDesc_M0_O_M1,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_z_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_lse_grid_,
arg.p_d_grid_,
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
arg.d_grid_desc_m_,
arg.ygrad_grid_desc_m0_o_m1_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_drop_,
arg.seed_,
arg.offset_,
arg.m_raw_padded_,
arg.n_raw_padded_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time += launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time += launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
// TODO: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
const auto c_extent_lowest = Gemm1NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b_stride_lowest =
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_m_o_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const InputDataType* p_a,
const InputDataType* p_b,
ZDataType* p_z,
const InputDataType* p_b1,
const InputDataType* p_c,
const LSEDataType* p_lse,
DDataType* p_d_grid,
const InputDataType* p_ygrad_grid,
OutputDataType* p_qgrad_grid,
OutputDataType* p_kgrad_grid,
OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_a,
p_b,
p_z,
p_b1,
p_c,
p_lse,
p_d_grid,
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_vgrad_grid,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
d_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_drop,
seeds};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_z,
const void* p_b1,
const void* p_c,
const void* p_lse,
void* p_d_grid,
const void* p_ygrad_grid,
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& d_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z),
static_cast<const InputDataType*>(p_b1),
static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse),
static_cast<DDataType*>(p_d_grid),
static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
d_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_drop,
seeds);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1,
false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx_m, // mblock
make_multi_index(block_work_idx_m, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
......@@ -1511,14 +1511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
template <typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatLSE,
typename FloatD,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename SElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1,
typename KGridDesc_N_K,
typename ZGridDesc_M_N,
typename VGridDesc_O0_N_O1,
typename YGridDesc_M_O,
typename LSEGridDesc_M,
typename DGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
}
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
template <typename AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2>
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2& acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2)
{
// acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 to a_src_thread_desc_k0_m_k1
// m0_m1_m2_m3 -> k0
// n0_n1_n2 -> m
// m4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
const auto m3 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
const auto m4 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
return transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2, m3)),
make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2)),
make_pass_through_transform(m4)),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = Gemm0NWaves;
constexpr index_t NWave = Gemm0MWaves;
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
template <typename Gemm2Param>
__host__ __device__ static constexpr auto GetA2BlockDescriptor_K0_M_K1()
{
return make_naive_tensor_descriptor(
make_tuple(Number<Gemm2Param::A_K0>{},
Number<Gemm2Param::Gemm2_M>{},
Number<Gemm2Param::A_K1>{}),
make_tuple(Number<Gemm2Param::Gemm2_M + Gemm2Param::A_LdsPad>{} *
Number<Gemm2Param::A_K1>{},
Number<Gemm2Param::A_K1>{},
I1));
}
__host__ __device__ static constexpr bool
CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDesc_M_O& y_grid_desc_m_o)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr auto
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(const YGridDesc_M_O& y_grid_desc_m_o)
{
const auto M = y_grid_desc_m_o.GetLength(I0);
const auto O = y_grid_desc_m_o.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
y_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
template <typename SrcBlockwiseGemm>
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4(const LSEGridDesc_M& lse_grid_desc_m)
{
const index_t M = lse_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock;
constexpr auto SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
SrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// M0 MXdlPerWave, M1 MWave, M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
const auto M0 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I0);
const auto M1 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I2);
const auto M2 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I4);
const auto M3 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I5);
const auto M4 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I6);
const auto lse_grid_desc_mb_m0_m1_m2_m3_m4 = transform_tensor_descriptor(
lse_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, M0, M1, M2, M3, M4))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}));
return lse_grid_desc_mb_m0_m1_m2_m3_m4;
}
__device__ static auto MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1)
{
const auto O0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto O1 = k_grid_desc_k0_n_k1.GetLength(I2);
const auto O = O0 * O1;
const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto k_grid_desc_n_o = transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
k_grid_desc_n_o,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
__device__ static auto MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
{
const auto O0 = v_grid_desc_o0_n_o1.GetLength(I0);
const auto N = v_grid_desc_o0_n_o1.GetLength(I1);
const auto O1 = v_grid_desc_o0_n_o1.GetLength(I2);
const auto O = O0 * O1;
const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto v_grid_desc_n_o = transform_tensor_descriptor(
v_grid_desc_o0_n_o1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
v_grid_desc_n_o,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
__device__ static auto MakeQGradGridDesc_M_K(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1)
{
const auto K_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto K_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
return transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(M),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const KGridDesc_N_K& k_grid_desc_n_k)
{
return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, KPerBlock, KGridDesc_N_K>(
k_grid_desc_n_k);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(YGridDesc_M_O{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
// Q / K / V / dY
struct GemmBlockwiseCopy
{
// Q matrix in LDS memory, dst of blockwise copy
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// K matrix in LDS memory, dst of blockwise copy
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// V matrix in LDS memory, dst of blockwise copy
static constexpr auto v_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// dY matrix in LDS memory, dst of blockwise copy
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
template <typename GridDesc_K0_M_K1>
using QBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(q_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_N_K1>
using KBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_N_K1>
using VBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(v_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_M_K1>
using YGradBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(ygrad_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr auto gemm_tile_q_block_slice_copy_step =
make_multi_index(0, -MPerBlock, 0);
static constexpr auto gemm_tile_ygrad_block_slice_copy_step =
make_multi_index(0, -MPerBlock, 0);
};
// S / dP Gemm (type 1 rcc)
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>;
};
// dV / dK Gemm (type 2 rrr)
template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2>
struct Gemm1
{
private:
static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I0);
static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I1);
static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I2);
static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I3);
static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I4);
static constexpr auto m3 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
static constexpr auto m4 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I6);
static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I7);
// M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
static constexpr auto M3 = ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
public:
static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / m4 / M3>{};
static constexpr auto AThreadSliceLength_M = Number<n0 * n1 * n2>{};
static constexpr auto AThreadSliceLength_K1 = Number<m4>{};
// A source matrix layout in AccVGPR
static constexpr auto a_src_thread_desc_k0_m_k1 =
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{});
// A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bn0_k_bn1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, 1, 1>(
BBlockDesc_BK0_N_BK1{});
}
static constexpr auto ASrcScalarPerVector = m4;
using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy =
ThreadwiseTensorSliceTransfer_StaticToStatic<FloatGemmAcc,
GemmDataType,
decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
ElementwiseOp,
AThreadSliceLengths_K0_M_K1,
Sequence<1, 0, 2>,
2,
ASrcScalarPerVector>;
// for a_block_slice_copy_step to be able to address static buffers, it MUST be a
// tuple-based container as well as containing ONLY integral constants
static constexpr auto a_block_slice_copy_step = make_tuple(AThreadSliceLength_K0, I0, I0);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4
static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1
static constexpr index_t GemmMRepeat = NXdlPerWave; // 1 // 1
static constexpr index_t GemmNRepeat = Gemm1NXdlPerWave; // 1 // 2
static constexpr index_t GemmKLoop = MPerBlock / Gemm1KPerBlock; // 128/32=4 // 64/32=2
static constexpr index_t B_K3 = GemmKPack; // 4 // 4
static constexpr index_t B_K2 = M3; // 2 // 2
static constexpr index_t B_K1 = Gemm1KPerBlock / B_K2 / B_K3; // 4 // 4
static constexpr index_t B_K0 = GemmKLoop; // 4 // 2
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = b_block_desc_bn0_k_bn1.GetLength(I0);
const auto K_ = b_block_desc_bn0_k_bn1.GetLength(I1);
const auto N1_ = b_block_desc_bn0_k_bn1.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 64)
b_block_desc_bn0_k_bn1,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 64
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_block_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(
GemmNRepeat, GemmNWave, NPerXdl)), //(1, 1, 32) //(2, 1, 32)
make_unmerge_transform(
make_tuple(B_K0, B_K1, B_K2, B_K3))), //(4, 4, 2, 4) //(2, 4, 2, 4)
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));
}
static constexpr auto b_block_desc_n0_n1_n2_k0_k1_k2_k3 =
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3();
using BThreadSlice_N0_N1_N2_K0_K1_K2_K3 = Sequence<GemmNRepeat, 1, 1, 1, B_K1, 1, B_K3>;
static constexpr auto b_thread_desc_n0_n1_n2_k0_k1_k2_k3 =
make_naive_tensor_descriptor_packed(
make_tuple(Number<GemmNRepeat>{}, I1, I1, I1, Number<B_K1>{}, I1, Number<B_K3>{}));
__host__ __device__ static constexpr auto MakeBThreadDesc_K0_N_K1()
{
constexpr auto b_thread_desc_n_k = transform_tensor_descriptor(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<GemmNRepeat>{}, I1, I1)),
make_merge_transform_v3_division_mod(
make_tuple(I1, Number<B_K1>{}, I1, Number<B_K3>{}))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_thread_desc_n_k,
make_tuple(make_pass_through_transform(Number<GemmNRepeat>{}),
make_unmerge_transform(make_tuple(Number<B_K1>{}, Number<B_K3>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
}
static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
1,
true>;
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0);
static constexpr auto b_block_reset_copy_step = make_multi_index(0, 0, 0, -B_K0, 0, 0, 0);
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)),
NPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
NXdlPerWave,
Gemm1NXdlPerWave,
GemmKPack,
true, // TransposeC
GemmKPack, // AMmaKStride
GemmKPack>;
};
// dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
template <index_t Sum_K_ = NPerXdl * 2>
struct Gemm2Params_
{
static constexpr index_t Gemm2_M = MPerBlock; // 128 // 64
static constexpr index_t Gemm2_K = NPerBlock; // 128 // 128
static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 32 // 64
static constexpr index_t Sum_K = Sum_K_;
static constexpr index_t A_K1 = 8; // dS will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1;
static constexpr index_t A_LdsPad = 0; // how many multiples of K1 per M * K1 elements
static_assert(Sum_K % NPerXdl == 0, "");
static constexpr index_t GemmNWave = Gemm2_N / Gemm2NXdlPerWave / NPerXdl; // 1 // 2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave; // 4 // 2
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack =
math::max(A_K1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
static constexpr index_t B_K0 = GemmKLoop; // 2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Gemm2Params::Sum_K - 1;
constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2Params::Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
return Sequence<m0, k0, m1, k1, m2, k2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
}
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1()
{
return generate_sequence_v2(
[](auto I) { return GetABlockSliceLengths_M0_K0_M1_K1_M2_K2().At(I); },
Number<4>{});
}
using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
};
using Gemm2Params = Gemm2Params_<>; // tune later
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
struct Gemm2
{
private:
static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); // repeat
static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); // wave
static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); // xdl
static constexpr auto M3 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
static constexpr auto M4 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
public:
// A source matrix layout in VGPR, src of VGPR-to-LDS copy
static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// // B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_n0_k_n1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm2Params::GemmMRepeat,
Gemm2Params::GemmMWave,
MPerXdl>(ABlockDesc_K0_M_K1{});
}
template <typename BBlockDesc_K0_N_K1>
__host__ __device__ static constexpr auto
MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_K0_N_K1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm2Params::GemmNRepeat, 1, 1>(
BBlockDesc_K0_N_K1{});
}
__host__ __device__ static constexpr auto MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2()
{
const auto K0_ = a_block_desc_k0_m_k1.GetLength(I0);
const auto M_ = a_block_desc_k0_m_k1.GetLength(I1);
const auto K1_ = a_block_desc_k0_m_k1.GetLength(I2);
const auto a_block_desc_m_k = transform_tensor_descriptor(
a_block_desc_k0_m_k1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0_, K1_)),
make_pass_through_transform(M_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
return transform_tensor_descriptor(
a_block_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(I1, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
}
// Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
// destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
// operation before returning the values
__host__ __device__ static auto MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2()
{
const auto a_thread_origin_on_block_idx =
ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);
constexpr auto a_block_slice_lengths_m0_k0_m1_k1 =
typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1{}; // mrepeat, nrepeat,
// mwaves, nwaves,
return make_tuple(
a_thread_origin_on_block_idx[I0], // mrepeat
a_thread_origin_on_block_idx[I1], // nrepeat
a_thread_origin_on_block_idx[I2] % a_block_slice_lengths_m0_k0_m1_k1[I2], // mwave
a_thread_origin_on_block_idx[I3] % a_block_slice_lengths_m0_k0_m1_k1[I3], // nwave
a_thread_origin_on_block_idx[I4], // xdlops
a_thread_origin_on_block_idx[I5],
a_thread_origin_on_block_idx[I6],
a_thread_origin_on_block_idx[I7]);
}
static constexpr auto a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2 =
MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2();
using ASrcBlockSliceWindowIterator =
SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
Sequence<0, 1, 2, 3>,
typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1,
false>;
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
GemmDataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2),
ElementwiseOp,
Sequence<Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I0), // ThreadSliceLengths
Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I1),
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = b_block_desc_n0_k_n1.GetLength(I0);
const auto K_ = b_block_desc_n0_k_n1.GetLength(I1);
const auto N1_ = b_block_desc_n0_k_n1.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
b_block_desc_n0_k_n1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_block_desc_n_k,
make_tuple(
make_unmerge_transform(make_tuple(Gemm2Params::GemmNRepeat,
Gemm2Params::GemmNWave,
NPerXdl)), //(1, 1, 32) //(1, 2, 32)
make_unmerge_transform(
make_tuple(Gemm2Params::B_K0,
Gemm2Params::B_K1,
Gemm2Params::B_K2,
Gemm2Params::B_K3))), //(2, 4, 2, 8) //(2, 4, 2, 8)
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));
}
static constexpr auto b_block_desc_n0_n1_n2_k0_k1_k2_k3 =
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3();
using BThreadSlice_N0_N1_N2_K0_K1_K2_K3 =
Sequence<Gemm2Params::GemmNRepeat, 1, 1, 1, Gemm2Params::B_K1, 1, Gemm2Params::B_K3>;
static constexpr auto b_thread_desc_n0_n1_n2_k0_k1_k2_k3 =
make_naive_tensor_descriptor_packed(make_tuple(Number<Gemm2Params::GemmNRepeat>{},
I1,
I1,
I1,
Number<Gemm2Params::B_K1>{},
I1,
Number<Gemm2Params::B_K3>{}));
__host__ __device__ static constexpr auto MakeBThreadDesc_K0_N_K1()
{
constexpr auto b_thread_desc_n_k = transform_tensor_descriptor(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<Gemm2Params::GemmNRepeat>{}, I1, I1)),
make_merge_transform_v3_division_mod(make_tuple(
I1, Number<Gemm2Params::B_K1>{}, I1, Number<Gemm2Params::B_K3>{}))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_thread_desc_n_k,
make_tuple(make_pass_through_transform(Number<Gemm2Params::GemmNRepeat>{}),
make_unmerge_transform(make_tuple(Number<Gemm2Params::B_K1>{},
Number<Gemm2Params::B_K3>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
}
static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
1,
true>;
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0);
static constexpr auto b_block_reset_copy_step =
make_multi_index(0, 0, 0, -Gemm2Params::B_K0, 0, 0, 0);
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1),
decltype(MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_k0_m_k1)),
decltype(MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)),
MPerBlock,
Gemm1NPerBlock,
Gemm2Params::Sum_K,
MPerXdl,
NPerXdl,
Gemm2Params::GemmMRepeat,
Gemm2Params::GemmNRepeat,
Gemm2Params::GemmKPack,
true, // TransposeC
Gemm2Params::GemmKPack *
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params::GemmKPack, false>{}
.K0PerXdlops,
Gemm2Params::GemmKPack>;
static constexpr auto c_block_slice_copy_step =
make_multi_index(-Gemm2Params::GemmMRepeat, 0, 0, 0, 0, 0, 0, 0);
template <typename CGradDesc_M_N>
__host__ __device__ static auto
MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(const CGradDesc_M_N& c_grid_desc_m_n)
{
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmMWave, MPerXdl)),
make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmNWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
c_grid_desc_m0_n0_m1_n1_m2_n2);
return c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4;
}
static constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
__host__ __device__ static auto GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
{
return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
}
template <typename CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
OutputDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
ElementwiseOp, // CElementwiseOperation
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
2, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
// dY matrix in LDS memory, dst of blockwise copy
static constexpr auto ygrad_block_desc_o0_m_o1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
__host__ __device__ static constexpr auto MakeYGradBlockDesc_M_O()
{
const auto O0_ = ygrad_block_desc_o0_m_o1.GetLength(I0);
const auto M_ = ygrad_block_desc_o0_m_o1.GetLength(I1);
const auto O1_ = ygrad_block_desc_o0_m_o1.GetLength(I2);
static_assert(O0_ * O1_ == BlockSliceLength_O_, "");
static_assert(M_ == BlockSliceLength_M_, "");
return transform_tensor_descriptor( //(128, 64)
ygrad_block_desc_o0_m_o1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(O0_, O1_)), //(8, 8)
make_pass_through_transform(M_)), // 128
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
}
static constexpr auto ygrad_block_desc_m_o = MakeYGradBlockDesc_M_O();
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
static constexpr auto q_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto k_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto v_block_desc_k0_n_k1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto p_slash_sgrad_block_desc_k0_m_k1 =
GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto v_block_space_size_aligned =
math::integer_least_multiple(v_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_slash_sgrad_block_space_size_aligned = math::integer_least_multiple(
p_slash_sgrad_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto v_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto ygrad_block_space_offset =
k_block_space_size_aligned.value + v_block_space_size_aligned.value;
static constexpr auto q_block_space_offset = k_block_space_size_aligned.value +
v_block_space_size_aligned.value +
ygrad_block_space_size_aligned.value;
static constexpr auto p_slash_sgrad_block_space_offset =
k_block_space_size_aligned.value + v_block_space_size_aligned.value +
ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(k_block_space_size_aligned.value + v_block_space_size_aligned.value +
ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(p_slash_sgrad_bytes_end, softmax_bytes_end, c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
typename Block2CTileMap,
typename C0MatrixMask,
typename YGradGridDesc_O0_M_O1>
__device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid,
ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const FloatD* __restrict__ p_d_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const SElementwiseOperation& s_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const DGridDesc_M& d_grid_desc_m,
const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t z_random_matrix_offset,
const index_t raw_n_padded,
const index_t block_idx_n)
{
ignore = p_y_grid;
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_v_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_o0_m_o1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];
// HACK: this force n_block_data_idx_on_grid into SGPR
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);
const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = MPerBlock / Gemm1KPerBlock;
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcc)
// dV_NOM / dK_NKM Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm2 crr)
// LDS allocation for Q / K / V / dY
auto q_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
GemmBlockwiseCopy::q_block_desc_k0_m_k1.GetElementSpaceSize());
auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());
auto v_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::v_block_space_offset,
GemmBlockwiseCopy::v_block_desc_k0_n_k1.GetElementSpaceSize());
auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1.GetElementSpaceSize());
// Q matrix blockwise copy
auto gemm_tile_q_blockwise_copy =
typename GemmBlockwiseCopy::template QBlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
a_element_op,
GemmBlockwiseCopy::q_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// K matrix blockwise copy
auto gemm_tile_k_blockwise_copy =
typename GemmBlockwiseCopy::template KBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
GemmBlockwiseCopy::k_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// V matrix blockwise copy
auto gemm_tile_v_blockwise_copy =
typename GemmBlockwiseCopy::template VBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
v_grid_desc_o0_n_o1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b1_element_op,
GemmBlockwiseCopy::v_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dY matrix blockwise copy
auto gemm_tile_ygrad_blockwise_copy =
typename GemmBlockwiseCopy::template YGradBlockwiseCopy<decltype(
ygrad_grid_desc_o0_m_o1)>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
a_element_op,
GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
//
// set up S / dP Gemm (type 1 rcc)
//
// S: blockwise gemm
auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
// dP: blockwise gemm
// we need separate blockwise gemm object because we need separate thread buffer
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
//
// set up dV / dK Gemm (type 2 rrr)
//
using Gemm1 =
Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()),
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2())>;
// Gemm1: VGPR allocation for A and B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
// dV: A matrix blockwise copy
auto vgrad_gemm_tile_p_blockwise_copy =
typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
// dV: blockwise gemm
auto vgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0), make_tuple(0, 0, 0, 0)};
// dV: B matrix blockwise copy
auto ygrad_thread_origin = vgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
// dV: B matrix LDS-to-VGPR blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy = typename Gemm1::BBlockwiseCopy{
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
make_multi_index(0, // nrepeat
ygrad_thread_origin[I1], // nwave
ygrad_thread_origin[I2], // nperxdl
0, // k0
0, // k1
ygrad_thread_origin[I3] / Gemm1::GemmKPack, // k2
0)};
auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
// dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_o0_n_o1);
// dK: A matrix blockwise copy
auto kgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
tensor_operation::element_wise::PassThrough{}};
// dK: blockwise gemm
auto kgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0), make_tuple(0, 0, 0, 0)};
// dK: B matrix blockwise copy
auto q_thread_origin = kgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
// dK: B matrix LDS-to-VGPR blockwise copy
auto kgrad_gemm_tile_q_blockwise_copy = typename Gemm1::BBlockwiseCopy{
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
make_multi_index(0, // nrepeat
q_thread_origin[I1], // nwave
q_thread_origin[I2], // nperxdl
0, // k0
0, // k1
q_thread_origin[I3] / Gemm1::GemmKPack, // k2
0)};
auto kgrad_thread_buf = kgrad_blockwise_gemm.GetCThreadBuffer();
// dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(k_grid_desc_k0_n_k1);
//
// set up dQ Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
Gemm2::a_block_desc_k0_m_k1.GetElementSpaceSize());
auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
// dQ: A matrix VGPR-to-LDS blockwise copy
auto qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
Gemm2::MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2(),
tensor_operation::element_wise::PassThrough{}};
// dQ: blockwise gemm
auto qgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
qgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto k_thread_origin = qgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
// dQ: B matrix LDS-to-VGPR blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy = typename Gemm2::BBlockwiseCopy{
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
make_multi_index(0, // nrepeat
k_thread_origin[I1], // nwave
k_thread_origin[I2], // nperxdl
0, // k0
0, // k1
k_thread_origin[I3] / Gemm2Params::GemmKPack, // k2
0)}; // k3
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
// dQ: transform output tensor descriptors
const auto qgrad_grid_desc_m_k = MakeQGradGridDesc_M_K(q_grid_desc_k0_m_k1);
const auto qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4 =
Gemm2::MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(qgrad_grid_desc_m_k);
// dQ: C VGPR-to-global copy
const auto qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4 =
Gemm2::GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4() +
make_multi_index((num_gemm0_m_block_outer_loop - 1) * Gemm2Params::GemmMRepeat,
I0,
I0,
I0,
I0,
I0,
I0,
I0);
auto qgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
decltype(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4),
decltype(scale_rp_dropout)>(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
scale_rp_dropout);
//
// Blockwise softmax
//
// get acc0 8D thread cluster
constexpr auto thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2 =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths() /
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I0);
constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I1);
constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I2);
constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I3);
constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I4);
constexpr auto tm3 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I5);
constexpr auto tm4 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I6);
constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I7);
// get acc0 thread map
constexpr auto n0_m_n1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(tn0 * tn1, tn2)),
make_pass_through_transform(I1)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto threadid_to_n0_m_n1_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(tn0 * tn1, tm0 * tm1 * tm2 * tm3 * tm4, tn2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto threadid_to_m_n_thread_cluster_adaptor =
chain_tensor_adaptors(n0_m_n1_to_m_n_adaptor, threadid_to_n0_m_n1_adaptor);
// get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto m3 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto m4 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(tm0 * tm1 * tm2 * tm3 * tm4, tn0 * tn1 * tn2));
constexpr auto thread_slice_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2 * m3 * m4, n0 * n1 * n2));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, rp_dropout};
auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
constexpr auto lse_thread_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2, m3, m4));
auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
lse_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto lse_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatLSE,
FloatLSE,
decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MGroupNum
m3, // MInputNum
m4, // RegisterNum
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
block_work_idx_n, // NBlockId
0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], //
0, //
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto P_M3 = p_block_lengths[I5];
constexpr auto P_M4 = p_block_lengths[I6];
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto ygrad_thread_desc_m_o = make_naive_tensor_descriptor_packed(
make_tuple(YDotYGrad_M_O::ThreadSliceLength_M, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
constexpr auto ygrad_thread_cluster_desc = make_cluster_descriptor(
Sequence<YDotYGrad_M_O::ThreadClusterLength_M, YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1>{});
const auto ygrad_thread_cluster_idx = ygrad_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto ygrad_thread_data_on_block_idx =
ygrad_thread_cluster_idx * ygrad_thread_desc_m_o.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(num_gemm0_m_block_outer_loop - 1, I0, I0, I0) +
y_thread_data_on_block_idx;
// performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
// performs for ygrad
auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
GemmDataType,
FloatGemmAcc,
decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
decltype(ygrad_thread_desc_m_o),
decltype(ygrad_thread_desc_m_o.GetLengths()),
Sequence<0, 1>,
1, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
ygrad_thread_data_on_block_idx);
constexpr auto y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, P_M0, P_M1, P_M2, P_M3, P_M4));
constexpr auto y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4 =
lse_thread_desc_mb_m0_m1_m2_m3_m4; // reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto y_dot_ygrad_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4),
decltype(y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
if constexpr(Deterministic)
{
block_sync_lds();
}
// Initialize dK&dV
kgrad_thread_buf.Clear();
vgrad_thread_buf.Clear();
// load k
gemm_tile_k_blockwise_copy.Run(k_grid_desc_k0_n_k1,
k_grid_buf,
GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf,
I0);
// load v
gemm_tile_v_blockwise_copy.Run(v_grid_desc_o0_n_o1,
v_grid_buf,
GemmBlockwiseCopy::v_block_desc_k0_n_k1,
v_block_buf,
I0);
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
do
{
auto m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
// load ygrad
gemm_tile_ygrad_blockwise_copy.Run(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf,
GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1,
ygrad_block_buf,
I0);
block_sync_lds();
//
// calculate Y dot dY
//
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
d_grid_buf,
lse_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
lse_grid_buf,
lse_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0),
lse_thread_buf);
// gemm dP
// dP = dY * V^T
pgrad_thread_buf.Clear();
pgrad_blockwise_gemm.Run(ygrad_block_buf, v_block_buf, pgrad_thread_buf);
// gemm S
// S = Q * K^T
s_slash_p_thread_buf.Clear();
gemm_tile_q_blockwise_copy.Run(q_grid_desc_k0_m_k1,
q_grid_buf,
GemmBlockwiseCopy::q_block_desc_k0_m_k1,
q_block_buf,
I0);
block_sync_lds();
s_blockwise_gemm.Run(q_block_buf, k_block_buf, s_slash_p_thread_buf);
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_element_op(s_slash_p_thread_buf(i),
masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i]);
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global
if(p_z_grid)
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
block_sync_lds(); // wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
constexpr auto pgrad_thread_idx_to_m_n_adaptor =
pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout
bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
// gemm dV
// dV = P_drop^T * dY
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// main body
static_for<0, num_gemm1_k_block_inner_loop, 1>{}([&](auto i) {
vgrad_gemm_tile_p_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
s_slash_p_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.Run(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
ygrad_block_buf,
Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
gemm1_b_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm1::b_block_slice_copy_step);
block_sync_lds();
vgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_thread_buf, vgrad_thread_buf);
// block_sync_lds();
});
} // end gemm dV
// gemm dK
// dK = scalar * dS^T * Q
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// main body
static_for<0, num_gemm1_k_block_inner_loop, 1>{}([&](auto i) {
kgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
sgrad_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
kgrad_gemm_tile_q_blockwise_copy.Run(Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
q_block_buf,
Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
gemm1_b_thread_buf);
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm1::b_block_slice_copy_step);
block_sync_lds();
kgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_thread_buf, kgrad_thread_buf);
// block_sync_lds();
});
} // end gemm dK
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_gemm2_loop = NPerBlock / Gemm2Params::Sum_K;
static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
"");
// TODO: tune gemm2 pipeline
// gemm dQ
// dQ = scalar * dS * K
qgrad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dQ
// load VGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range = make_tuple(
sgrad_slice_idx[I2],
sgrad_slice_idx[I2] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I2));
constexpr auto nwave_range = make_tuple(
sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(
sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
sgrad_thread_buf,
Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
gemm2_a_block_buf);
}
// block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
k_block_buf,
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
gemm2_b_thread_buf);
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before read
qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_thread_buf, qgrad_thread_buf);
}); // end gemm dQ
// atomic_add dQ
qgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
qgrad_thread_buf,
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
qgrad_grid_buf);
// move slice window
gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
GemmBlockwiseCopy::gemm_tile_q_block_slice_copy_step); // step M
gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1,
GemmBlockwiseCopy::gemm_tile_ygrad_block_slice_copy_step); // step M
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
Gemm1::b_block_reset_copy_step); // rewind M
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
Gemm2::b_block_reset_copy_step); // rewind K
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
Gemm1::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
Gemm1::b_block_reset_copy_step); // rewind M
qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0));
y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(-1, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop
// shuffle dK&dV and write
{
static_assert(Gemm1::GemmMRepeat % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1::GemmNRepeat % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = Gemm1::GemmMWave;
constexpr index_t NWave = Gemm1::GemmNWave;
// TODO: hacky, fix it!
// thread desc same with kgrad_blockwise_gemm
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
// block desc same with kgrad_blockwise_gemm
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
vgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2)), // M2 = MPerXdl
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2, // N2 * N3 * N4 = NPerXdl
N3,
N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// index same with kgrad_blockwise_gemm
const auto c_thread_mtx_on_block =
vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto vgrad_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
tensor_operation::element_wise::Scale{rp_dropout}};
// shuffle: blockwise copy C from LDS to global
auto vgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(vgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// shuffle: threadwise copy C from VGPR to LDS
auto kgrad_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
scale_rp_dropout};
// shuffle: blockwise copy C from LDS to global
auto kgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(kgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = SpaceFillingCurve<
Sequence<Gemm1::GemmMRepeat, Gemm1::GemmNRepeat, 1, 1, 1, N2, 1, N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
1,
N2,
1,
N4>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, NPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
// dK
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
kgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
kgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
kgrad_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
kgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
kgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
kgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
}
});
// dV
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
vgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
vgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
vgrad_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
vgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
vgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
vgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
}
});
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
template <typename InputDataType,
typename OutputDataType,
typename ZDataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatLSE,
typename FloatD,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename SElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1,
typename KGridDesc_N_K,
typename ZGridDesc_M_N,
typename VGridDesc_N0_O_N1,
typename YGridDesc_M_O,
typename LSEGridDesc_M,
typename DGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
index_t B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
// Gemm1
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
}
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
template <typename AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2>
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2& acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2)
{
// acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 to a_src_thread_desc_k0_m_k1
// m0_m1_m2_m3 -> k0
// n0_n1_n2 -> m
// m4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
const auto m3 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
const auto m4 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
return transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2, m3)),
make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2)),
make_pass_through_transform(m4)),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1K1),
make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = Gemm0NWaves;
constexpr index_t NWave = Gemm0MWaves;
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
template <typename Gemm2Param>
__host__ __device__ static constexpr auto GetA2BlockDescriptor_K0_M_K1()
{
return make_naive_tensor_descriptor(
make_tuple(Number<Gemm2Param::A_K0>{},
Number<Gemm2Param::Gemm2_M>{},
Number<Gemm2Param::A_K1>{}),
make_tuple(Number<Gemm2Param::Gemm2_M + Gemm2Param::A_LdsPad>{} *
Number<Gemm2Param::A_K1>{},
Number<Gemm2Param::A_K1>{},
I1));
}
template <typename Gemm2Param>
__host__ __device__ static constexpr auto GetB2BlockDescriptor_K0_N_K1()
{
return make_naive_tensor_descriptor(
make_tuple(Number<Gemm2Param::B_K0>{},
Number<Gemm2Param::Gemm2_N>{},
Number<Gemm2Param::B_K1>{}),
make_tuple(Number<Gemm2Param::Gemm2_N + Gemm2Param::B_LdsPad>{} *
Number<Gemm2Param::B_K1>{},
Number<Gemm2Param::B_K1>{},
I1));
}
__host__ __device__ static constexpr bool
CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDesc_M_O& y_grid_desc_m_o)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(const YGridDesc_M_O& y_grid_desc_m_o)
{
const auto M = y_grid_desc_m_o.GetLength(I0);
const auto O = y_grid_desc_m_o.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
y_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
template <typename SrcBlockwiseGemm>
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4(const LSEGridDesc_M& lse_grid_desc_m)
{
const index_t M = lse_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock;
constexpr auto SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
SrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// M0 MXdlPerWave, M1 MWave, M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
const auto M0 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I0);
const auto M1 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I2);
const auto M2 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I4);
const auto M3 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I5);
const auto M4 = SrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2.GetLength(I6);
const auto lse_grid_desc_mb_m0_m1_m2_m3_m4 = transform_tensor_descriptor(
lse_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, M0, M1, M2, M3, M4))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}));
return lse_grid_desc_mb_m0_m1_m2_m3_m4;
}
__device__ static auto MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1)
{
const auto O0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto O1 = k_grid_desc_k0_n_k1.GetLength(I2);
const auto O = O0 * O1;
const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto k_grid_desc_n_o = transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
k_grid_desc_n_o,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
__device__ static auto MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1)
{
const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
const auto O = v_grid_desc_n0_o_n1.GetLength(I1);
const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
const auto N = N0 * N1;
const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto v_grid_desc_n_o = transform_tensor_descriptor(
v_grid_desc_n0_o_n1,
make_tuple(make_pass_through_transform(O),
make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return transform_tensor_descriptor(
v_grid_desc_n_o,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
__device__ static auto MakeQGradGridDesc_M_K(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1)
{
const auto K_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto K_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
return transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(M),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const KGridDesc_N_K& k_grid_desc_n_k)
{
return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, Gemm1NPerBlock, KGridDesc_N_K>(
k_grid_desc_n_k);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(YGridDesc_M_O{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
// S / dP Gemm (type 1 rcc)
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename GridDesc_K0_M_K1>
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>;
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
};
// dV / dK Gemm (type 2 rrr)
template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2>
struct Gemm1
{
private:
static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I0);
static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I1);
static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I2);
static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I3);
static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I4);
static constexpr auto m3 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
static constexpr auto m4 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I6);
static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I7);
// M2 num_groups_per_blk, M3 num_input_blks, M4 group_size
static constexpr auto M3 = ASrcBlockDesc_M0_N0_M1_N1_M2_M3_M4_N2{}.GetLength(I5);
public:
static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / m4 / M3>{};
static constexpr auto AThreadSliceLength_M = Number<n0 * n1 * n2>{};
static constexpr auto AThreadSliceLength_K1 = Number<m4>{};
// A source matrix layout in AccVGPR
static constexpr auto a_src_thread_desc_k0_m_k1 =
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
ASrcThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2{});
// A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
static constexpr auto ASrcScalarPerVector = m4;
using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy =
ThreadwiseTensorSliceTransfer_StaticToStatic<FloatGemmAcc,
GemmDataType,
decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
ElementwiseOp,
AThreadSliceLengths_K0_M_K1,
Sequence<1, 0, 2>,
2,
ASrcScalarPerVector>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>;
// for a_block_slice_copy_step to be able to address static buffers, it MUST be a
// tuple-based container as well as containing ONLY integral constants
static constexpr auto a_block_slice_copy_step = make_tuple(AThreadSliceLength_K0, I0, I0);
static constexpr auto b_block_slice_copy_step =
make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
NPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
NXdlPerWave,
Gemm1NXdlPerWave,
GemmKPack,
true, // TransposeC
GemmKPack, // AMmaKStride
GemmKPack * XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
.K0PerXdlops /* BMmaKStride */>;
};
// dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
template <index_t Sum_K_ = NPerXdl * 2>
struct Gemm2Params_
{
static constexpr index_t Gemm2_M = MPerBlock;
static constexpr index_t Gemm2_K = NPerBlock;
static constexpr index_t Gemm2_N = Gemm1NPerBlock;
static constexpr index_t Sum_K = Sum_K_;
static constexpr index_t A_K1 = 8; // P will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1;
static constexpr index_t A_LdsPad = 0; // how many multiples of K1 per M * K1 elements
static constexpr index_t B_K1 = B1K1; // dY assumed row-major, typically =2 for fp16
static constexpr index_t B_K0 = Sum_K / B_K1;
static constexpr index_t B_LdsPad = 0; // how many multiples of K1 per N * K1 elements
static_assert(Sum_K % NPerXdl == 0, "");
static constexpr index_t BSrcVectorDim = 1; // Gemm2_N dimension
static constexpr index_t BSrcScalarPerVector = 4;
static constexpr index_t GemmNWave = Gemm2_N / Gemm2NXdlPerWave / NPerXdl;
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;
static constexpr index_t GemmKPack =
math::max(math::lcm(A_K1, B_K1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
using BThreadClusterLengths =
Sequence<BlockSize / (Gemm2_N / BSrcScalarPerVector), Gemm2_N / BSrcScalarPerVector, 1>;
using BThreadClusterArrangeOrder = Sequence<0, 2, 1>;
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Gemm2Params::Sum_K - 1;
constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2Params::Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
return Sequence<m0, k0, m1, k1, m2, k2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
}
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1()
{
return generate_sequence_v2(
[](auto I) { return GetABlockSliceLengths_M0_K0_M1_K1_M2_K2().At(I); },
Number<4>{});
}
using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
};
using Gemm2Params = Gemm2Params_<>; // tune later
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
struct Gemm2
{
private:
static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); // repeat
static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); // wave
static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); // xdl
static constexpr auto M3 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
static constexpr auto M4 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
public:
// A source matrix layout in VGPR, src of VGPR-to-LDS copy
static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_k0_n_k1 = GetB2BlockDescriptor_K0_N_K1<Gemm2Params>();
__host__ __device__ static constexpr auto MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2()
{
const auto K0_ = a_block_desc_k0_m_k1.GetLength(I0);
const auto M_ = a_block_desc_k0_m_k1.GetLength(I1);
const auto K1_ = a_block_desc_k0_m_k1.GetLength(I2);
const auto a_block_desc_m_k = transform_tensor_descriptor(
a_block_desc_k0_m_k1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0_, K1_)),
make_pass_through_transform(M_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
return transform_tensor_descriptor(
a_block_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(I1, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
}
// Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
// destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
// operation before returning the values
__host__ __device__ static auto MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2()
{
const auto a_thread_origin_on_block_idx =
ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);
constexpr auto a_block_slice_lengths_m0_k0_m1_k1 =
typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1{}; // mrepeat, nrepeat,
// mwaves, nwaves,
return make_tuple(
a_thread_origin_on_block_idx[I0], // mrepeat
a_thread_origin_on_block_idx[I1], // nrepeat
a_thread_origin_on_block_idx[I2] % a_block_slice_lengths_m0_k0_m1_k1[I2], // mwave
a_thread_origin_on_block_idx[I3] % a_block_slice_lengths_m0_k0_m1_k1[I3], // nwave
a_thread_origin_on_block_idx[I4], // xdlops
a_thread_origin_on_block_idx[I5],
a_thread_origin_on_block_idx[I6],
a_thread_origin_on_block_idx[I7]);
}
static constexpr auto a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2 =
MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2();
using ASrcBlockSliceWindowIterator =
SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
Sequence<0, 1, 2, 3>,
typename Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1,
false>;
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
GemmDataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2),
ElementwiseOp,
Sequence<Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I0), // ThreadSliceLengths
Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I1),
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
typename Gemm2Params::BBlockSliceLengths,
typename Gemm2Params::BThreadClusterLengths,
typename Gemm2Params::BThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_k0_n_k1),
typename Gemm2Params::BThreadClusterArrangeOrder, // access order == thread order
Sequence<1, 0, 2>,
Gemm2Params::BSrcVectorDim,
2, // DstVectorDim
Gemm2Params::BSrcScalarPerVector,
Gemm2Params::B_K1,
1,
1,
true,
true,
1>;
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXdl,
NPerXdl,
Gemm2Params::GemmMRepeat,
Gemm2Params::GemmNRepeat,
Gemm2Params::GemmKPack,
true>; // TranspossC
static constexpr auto b_block_slice_copy_step = make_multi_index(Gemm2Params::B_K0, 0, 0);
static constexpr auto c_block_slice_copy_step =
make_multi_index(-Gemm2Params::GemmMRepeat, 0, 0, 0, 0, 0, 0, 0);
static constexpr auto b_block_reset_copy_step =
make_multi_index(-NPerBlock / Gemm2Params::B_K1, 0, 0);
template <typename CGradDesc_M_N>
__host__ __device__ static auto
MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(const CGradDesc_M_N& c_grid_desc_m_n)
{
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmMWave, MPerXdl)),
make_unmerge_transform(make_tuple(I1, Gemm2Params::GemmNWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
c_grid_desc_m0_n0_m1_n1_m2_n2);
return c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4;
}
static constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
__host__ __device__ static auto GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
{
return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
}
template <typename CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
OutputDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
CGridDesc_M0_N0_M1_N1_M2_N2_N3_N4,
ElementwiseOp, // CElementwiseOperation
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
2, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
// PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
struct PGradGemmTile_M_N_O
{
// TODO:
// Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
// things more concise
template <typename YGradGridDesc_M0_O_M1_>
__device__ static auto
MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
{
const auto M0 = ygrad_grid_desc_m0_o_m1.GetLength(I0);
const auto O = ygrad_grid_desc_m0_o_m1.GetLength(I1);
const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2);
constexpr auto Y_O1 = AK1;
const auto Y_O0 = O / Y_O1;
const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor(
ygrad_grid_desc_m0_o_m1,
make_tuple(make_unmerge_transform(make_tuple(Y_O0, Y_O1)),
make_merge_transform_v3_division_mod(make_tuple(M0, M1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return ygrad_grid_desc_o0_m_o1;
}
template <typename VGridDesc_N0_O_N1_>
__device__ static auto MakeVGridDesc_O0_N_O1(const VGridDesc_N0_O_N1_& v_grid_desc_n0_o_n1)
{
const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
const auto O = v_grid_desc_n0_o_n1.GetLength(I1);
const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
constexpr auto V_O1 = BK1;
const auto V_O0 = O / V_O1;
const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
v_grid_desc_n0_o_n1,
make_tuple(make_unmerge_transform(make_tuple(V_O0, V_O1)),
make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return v_grid_desc_o0_n_o1;
}
};
// QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
struct QGradGemmTile_M_K_N
{
template <typename KGridDesc_K0_N_K1_>
__device__ static auto MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
{
const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
constexpr auto K_N1 = B1K1;
const auto K_N0 = N / K_N1;
const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(make_tuple(K_N0, K_N1)),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return k_grid_desc_n0_k_n1;
}
};
struct KGradGemmTile_N_K_M
{
// B position
template <typename QGridDesc_K0_M_K1_>
__device__ static auto MakeQGridDesc_M0_K_M1(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
{
const auto Q_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto Q_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
constexpr auto Q_M1 = B1K1;
const auto Q_M0 = M / Q_M1;
const auto q_grid_desc_m0_k_m1 = transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_unmerge_transform(make_tuple(Q_M0, Q_M1)),
make_merge_transform_v3_division_mod(make_tuple(Q_K0, Q_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return q_grid_desc_m0_k_m1;
}
};
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto b2_block_desc_k0_n_k1 = GetB2BlockDescriptor_K0_N_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b2_block_space_size_aligned = math::integer_least_multiple(
b2_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
static constexpr auto a2_block_space_offset = 0;
static constexpr auto b2_block_space_offset = a2_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::a2_block_space_size_aligned +
SharedMemTrait::b2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
softmax_bytes_end,
c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
typename Block2CTileMap,
typename C0MatrixMask,
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const InputDataType* __restrict__ p_k_grid,
ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const FloatD* __restrict__ p_d_grid,
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const SElementwiseOperation& s_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const DGridDesc_M& d_grid_desc_m,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t z_random_matrix_offset,
const index_t raw_n_padded,
const index_t block_idx_n)
{
ignore = p_y_grid;
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [N, K]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];
// HACK: this force n_block_data_idx_on_grid into SGPR
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);
const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = MPerBlock / Gemm1KPerBlock;
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcc)
// dV_NOM / dK_NKM Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm2 crr)
//
// set up S / dP Gemm (type 1 rcc)
//
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// Gemm0: gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gemm0_gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();
// S: A matrix blockwise copy
auto s_gemm_tile_q_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
a_element_op,
Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// S: B matrix blockwise copy
auto s_gemm_tile_k_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// S: blockwise gemm
auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
const auto s_gemm_tile_q_block_reset_copy_step =
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), -MPerBlock, 0);
const auto s_gemm_tile_k_block_reset_copy_step =
make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), 0, 0);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
// dP: transform input and output tensor descriptors
const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
const auto v_grid_desc_o0_n_o1 =
PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
// dP: A matrix blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
tensor_operation::element_wise::PassThrough{},
Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP: B matrix blockwise copy
auto pgrad_gemm_tile_v_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
v_grid_desc_o0_n_o1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP: blockwise gemm
// we need separate blockwise gemm object because we need separate thread buffer
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), -MPerBlock, 0);
const auto pgrad_gemm_tile_v_block_reset_copy_step =
make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), 0, 0);
const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
(ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
KPerBlock);
//
// set up dV / dK Gemm (type 2 rrr)
//
using Gemm1 =
Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()),
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2())>;
// Gemm1: VGPR allocation for A and LDS allocation for B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// dV: A matrix blockwise copy
auto vgrad_gemm_tile_p_blockwise_copy =
typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
// dV: B matrix blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
ygrad_grid_desc_m0_o_m1,
make_multi_index(MPerBlock / B1K1 * (num_gemm0_m_block_outer_loop - 1), 0, 0),
b1_element_op,
Gemm1::b_block_desc_bk0_n_bk1, // there n actually is k, k is N, so name can be
// b_block_desc_bn0_k_bn1
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
const auto vgrad_gemm_tile_ygrad_block_next_copy_step =
make_multi_index(-2 * MPerBlock / B1K1, 0, 0);
// dV: blockwise gemm
auto vgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
// dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_n0_o_n1);
// dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 =
KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);
// dK: A matrix blockwise copy
auto kgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
tensor_operation::element_wise::PassThrough{}};
// dK: B matrix blockwise copy
auto kgrad_gemm_tile_q_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
q_grid_desc_m0_k_m1,
make_multi_index(MPerBlock / B1K1 * (num_gemm0_m_block_outer_loop - 1), 0, 0),
b1_element_op,
Gemm1::b_block_desc_bk0_n_bk1, // there n actually is k, k is N, so name can be
// b_block_desc_bn0_k_bn1
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
const auto kgrad_gemm_tile_q_block_next_copy_step =
make_multi_index(-2 * MPerBlock / B1K1, 0, 0);
// dK: blockwise gemm
auto kgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
auto kgrad_thread_buf = kgrad_blockwise_gemm.GetCThreadBuffer();
// dK: transform input and output tensor descriptors
auto kgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(k_grid_desc_k0_n_k1);
//
// set up dQ Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
Gemm2::a_block_desc_k0_m_k1.GetElementSpaceSize());
auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
Gemm2::b_block_desc_k0_n_k1.GetElementSpaceSize());
// // dQ: transform input and output tensor descriptors
// const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
// Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dQ: transform input and output tensor descriptors
const auto k_grid_desc_n0_k_n1 =
QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
// dQ: A matrix VGPR-to-LDS blockwise copy
auto qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
Gemm2::MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2(),
tensor_operation::element_wise::PassThrough{}};
// dQ: B matrix global-to-LDS blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy =
typename Gemm2::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
k_grid_desc_n0_k_n1,
make_multi_index(n_block_data_idx_on_grid / Gemm2Params::B_K1, 0, 0),
tensor_operation::element_wise::PassThrough{},
Gemm2::b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dQ: blockwise gemm
auto qgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
// dQ: transform output tensor descriptors
const auto qgrad_grid_desc_m_k = MakeQGradGridDesc_M_K(q_grid_desc_k0_m_k1);
const auto qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4 =
Gemm2::MakeCGridDesc_M0_N0_M1_N1_M2_N2_N3_N4(qgrad_grid_desc_m_k);
// dQ: C VGPR-to-global copy
const auto qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4 =
Gemm2::GetCThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4() +
make_multi_index((num_gemm0_m_block_outer_loop - 1) * Gemm2Params::GemmMRepeat,
I0,
I0,
I0,
I0,
I0,
I0,
I0);
auto qgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
decltype(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4),
decltype(scale_rp_dropout)>(qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
scale_rp_dropout);
//
// Blockwise softmax
//
// get acc0 8D thread cluster
constexpr auto thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2 =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths() /
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I0);
constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I1);
constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I2);
constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I3);
constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I4);
constexpr auto tm3 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I5);
constexpr auto tm4 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I6);
constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_m3_m4_n2.At(I7);
// get acc0 thread map
constexpr auto n0_m_n1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(tn0 * tn1, tn2)),
make_pass_through_transform(I1)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto threadid_to_n0_m_n1_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(tn0 * tn1, tm0 * tm1 * tm2 * tm3 * tm4, tn2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto threadid_to_m_n_thread_cluster_adaptor =
chain_tensor_adaptors(n0_m_n1_to_m_n_adaptor, threadid_to_n0_m_n1_adaptor);
// get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto m3 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto m4 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(tm0 * tm1 * tm2 * tm3 * tm4, tn0 * tn1 * tn2));
constexpr auto thread_slice_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2 * m3 * m4, n0 * n1 * n2));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, rp_dropout};
auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
constexpr auto lse_thread_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2, m3, m4));
auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
lse_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto lse_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatLSE,
FloatLSE,
decltype(lse_grid_desc_mb_m0_m1_m2_m3_m4),
decltype(lse_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MGroupNum
m3, // MInputNum
m4, // RegisterNum
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
block_work_idx_n, // NBlockId
0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], //
0, //
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto P_M3 = p_block_lengths[I5];
constexpr auto P_M4 = p_block_lengths[I6];
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(num_gemm0_m_block_outer_loop - 1, I0, I0, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
constexpr auto y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, P_M0, P_M1, P_M2, P_M3, P_M4));
constexpr auto y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4 =
lse_thread_desc_mb_m0_m1_m2_m3_m4; // reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto y_dot_ygrad_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4),
decltype(y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
if constexpr(Deterministic)
{
block_sync_lds();
}
// Initialize dK&dV
kgrad_thread_buf.Clear();
vgrad_thread_buf.Clear();
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
do
{
auto m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
//
// calculate Y dot dY
//
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
d_grid_buf,
lse_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
lse_grid_buf,
lse_thread_desc_mb_m0_m1_m2_m3_m4,
make_tuple(I0, I0, I0, I0, I0, I0),
lse_thread_buf);
// S = Q * K^T
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
q_grid_desc_k0_m_k1,
Gemm0::a_block_desc_ak0_m_ak1,
s_gemm_tile_q_blockwise_copy,
q_grid_buf,
gemm0_a_block_buf,
Gemm0::a_block_slice_copy_step,
k_grid_desc_k0_n_k1,
Gemm0::b_block_desc_bk0_n_bk1,
s_gemm_tile_k_blockwise_copy,
k_grid_buf,
gemm0_b_block_buf,
Gemm0::b_block_slice_copy_step,
s_blockwise_gemm,
s_slash_p_thread_buf,
num_k_block_main_loop);
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_element_op(s_slash_p_thread_buf(i),
masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i]);
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global
if(p_z_grid)
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
block_sync_lds(); // wait for gemm1 LDS read
// gemm dV
// dV = P_drop^T * dY
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
vgrad_gemm_tile_p_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
s_slash_p_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf);
block_sync_lds();
vgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, vgrad_thread_buf);
block_sync_lds();
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm1::b_block_slice_copy_step);
vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
});
}
// tail
{
vgrad_gemm_tile_p_blockwise_copy.Run(
Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
s_slash_p_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
block_sync_lds();
vgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, vgrad_thread_buf);
}
} // end gemm dV
// gemm dP
block_sync_lds();
// dP = dY * V^T
// assume size K == size O so HasMainKBlockLoop is the same
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
Gemm0::a_block_desc_ak0_m_ak1, // reuse
pgrad_gemm_tile_ygrad_blockwise_copy,
ygrad_grid_buf,
gemm0_a_block_buf, // reuse
Gemm0::a_block_slice_copy_step, // reuse
v_grid_desc_o0_n_o1,
Gemm0::b_block_desc_bk0_n_bk1, // reuse
pgrad_gemm_tile_v_blockwise_copy,
v_grid_buf,
gemm0_b_block_buf, // reuse
Gemm0::b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
// dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
constexpr auto pgrad_thread_idx_to_m_n_adaptor =
pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout
bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_gemm2_loop = NPerBlock / Gemm2Params::Sum_K;
static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
"");
// TODO: tune gemm2 pipeline
// gemm dQ
// dQ = scalar * dS * K
qgrad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dQ
// load QGrad Gemm B
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
// load QGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range = make_tuple(
sgrad_slice_idx[I2],
sgrad_slice_idx[I2] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I2));
constexpr auto nwave_range = make_tuple(
sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(
sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
sgrad_thread_buf,
Gemm2::a_block_desc_m0_k0_m1_k1_m2_m3_m4_k2,
gemm2_a_block_buf);
}
// k slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
// sgrad slice window is moved by loop index
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
gemm2_b_block_buf);
block_sync_lds(); // sync before read
qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, qgrad_thread_buf);
}); // end gemm dQ
// atomic_add dQ
qgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
qgrad_thread_buf,
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4,
qgrad_grid_buf);
// gemm dK
// dK = scalar * dS^T * Q
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm1::b_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
kgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
sgrad_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
block_sync_lds();
kgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, kgrad_thread_buf);
block_sync_lds();
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1, Gemm1::b_block_slice_copy_step);
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
});
}
// tail
{
kgrad_gemm_tile_sgrad_blockwise_copy.Run(
Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
sgrad_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
block_sync_lds();
kgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, kgrad_thread_buf);
}
} // end gemm dK
// move slice window
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
s_gemm_tile_q_block_reset_copy_step); // rewind K and step M
s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_k0_n_k1,
s_gemm_tile_k_block_reset_copy_step); // rewind K
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1,
pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O and step M
pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_o0_n_o1,
pgrad_gemm_tile_v_block_reset_copy_step); // rewind O
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_n0_k_n1,
Gemm2::b_block_reset_copy_step); // rewind N
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1, kgrad_gemm_tile_q_block_next_copy_step); // step M
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, vgrad_gemm_tile_ygrad_block_next_copy_step); // step M
qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0));
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(-1, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop
// shuffle dK&dV and write
{
static_assert(NXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = Gemm0NWaves;
constexpr index_t NWave = Gemm0MWaves;
// TODO: hacky, fix it!
// thread desc same with kgrad_blockwise_gemm
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
// block desc same with kgrad_blockwise_gemm
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
vgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2)), // M2 = MPerXdl
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2, // N2 * N3 * N4 = NPerXdl
N3,
N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// index same with kgrad_blockwise_gemm
const auto c_thread_mtx_on_block =
vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto vgrad_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
tensor_operation::element_wise::Scale{rp_dropout}};
// shuffle: blockwise copy C from LDS to global
auto vgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(vgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// shuffle: threadwise copy C from VGPR to LDS
auto kgrad_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
scale_rp_dropout};
// shuffle: blockwise copy C from LDS to global
auto kgrad_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(kgrad_grid_desc_nblock_nperblock_oblock_operblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
make_multi_index(block_work_idx_n, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<NXdlPerWave, Gemm1NXdlPerWave, 1, 1, 1, N2, 1, N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
1,
N2,
1,
N4>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, NPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
// dK
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
kgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
kgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
kgrad_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
kgrad_grid_desc_nblock_nperblock_oblock_operblock,
kgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
kgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
kgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
}
});
// dV
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
vgrad_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
vgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
vgrad_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
vgrad_grid_desc_nblock_nperblock_oblock_operblock,
vgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
vgrad_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
vgrad_grid_desc_nblock_nperblock_oblock_operblock, c_global_step);
}
});
}
}
};
} // namespace ck
......@@ -917,7 +917,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
block_sync_lds();
}
do
{
auto n_block_data_idx_on_grid =
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
template <typename InputDataType,
typename FloatD,
typename CGridDesc_M_N,
typename DGridDesc_M,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock>
struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto WaveSize = 64;
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
__host__ __device__ static constexpr auto
MakeORSGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
{
const index_t M = d_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock;
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return d_grid_desc_mblock_mperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static_assert(BlockSize_ == BlockSliceLength_M_);
static constexpr auto ThreadSliceLength_M = Number<1>{};
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O = Number<1>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<BlockSliceLength_O_>{};
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatD,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType = StaticBuffer<AddressSpaceEnum::Vgpr, FloatD, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, NPerBlock>;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
return MPerBlock * sizeof(FloatD);
}
__device__ static void test() {}
template <typename Block2CTileMap>
__device__ static void Run(const InputDataType* __restrict__ p_y_grid,
const InputDataType* __restrict__ p_ygrad_grid,
FloatD* __restrict__ p_d_grid,
void* __restrict__ p_shared,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const DGridDesc_M& d_grid_desc_m,
const Block2CTileMap& block_2_ctile_map)
{
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_m.GetElementSpaceSize());
// divide block work by [M, O]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I0),
y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2))))
{
return;
}
const index_t block_work_idx_m = block_work_idx[I0];
constexpr auto d_thread_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1));
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatD,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatD*>(p_shared), MPerBlock);
// clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
index_t oblock_idx = 0;
do
{
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
ygrad_thread_buf);
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
});
});
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(0, 0, 1, 0));
oblock_idx++;
} while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
auto d_grid_desc_mblock_mperblock = MakeORSGridDescriptor_MBlock_MPerBlock(d_grid_desc_m);
auto d_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<FloatD,
FloatD,
decltype(d_thread_desc_mblock_mrepeat_mwave_mperxdl),
decltype(d_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, 1>,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{
d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx_m, // mblock
get_thread_local_1d_id()), // mperblock
ck::tensor_operation::element_wise::PassThrough{}};
// copy from VGPR to Global
d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, I0),
y_dot_ygrad_thread_accum_buf,
d_grid_desc_mblock_mperblock,
d_grid_buf);
}
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment