Commit 701879d0 authored by letaoqin's avatar letaoqin
Browse files

v1 device complete

parent 0539dbcd
...@@ -129,7 +129,7 @@ using DeviceGemmInstance = ...@@ -129,7 +129,7 @@ using DeviceGemmInstance =
#elif(DIM <= 128) #elif(DIM <= 128)
// clang-format off // clang-format off
using DeviceGemmInstance = using DeviceGemmInstance =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | | // ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | | // ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
......
...@@ -25,7 +25,7 @@ Kernel outputs: ...@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 32 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -104,11 +104,11 @@ static constexpr bool Deterministic = false; ...@@ -104,11 +104,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
// clang-format off // clang-format off
using DeviceGemmInstance = using DeviceGemmInstance =
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic| // ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | | // ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | | // ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | // ########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>; ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, InputDataType, OutputDataType, GemmDataType, ZDataType, LSEDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, ShuffleDataType, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 32, 32, 8, 8, 2, 32, 32, 4, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 4, 1, 1, S<1, 64, 1, 4>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>;
// clang-format on // clang-format on
#elif(DIM <= 64) #elif(DIM <= 64)
// clang-format off // clang-format off
......
...@@ -26,6 +26,7 @@ namespace device { ...@@ -26,6 +26,7 @@ namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename InputDataType, typename InputDataType,
typename D0DataType,
typename OutputDataType, typename OutputDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -36,6 +37,7 @@ template <typename GridwiseGemm, ...@@ -36,6 +37,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename D0GridDescriptor_M0_N0_M1_M2_N1_M3,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
...@@ -54,6 +56,7 @@ __global__ void ...@@ -54,6 +56,7 @@ __global__ void
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1( kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1(
const InputDataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
const D0DataType* __restrict__ p_d0_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const InputDataType* __restrict__ p_b1_grid, const InputDataType* __restrict__ p_b1_grid,
const InputDataType* __restrict__ p_c_grid, const InputDataType* __restrict__ p_c_grid,
...@@ -69,6 +72,7 @@ __global__ void ...@@ -69,6 +72,7 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
...@@ -114,6 +118,13 @@ __global__ void ...@@ -114,6 +118,13 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
...@@ -121,6 +132,7 @@ __global__ void ...@@ -121,6 +132,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
...@@ -137,6 +149,7 @@ __global__ void ...@@ -137,6 +149,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -156,6 +169,7 @@ __global__ void ...@@ -156,6 +169,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset, p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
tmp_p_d0_grid,
z_matrix_ptr, z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
...@@ -172,6 +186,7 @@ __global__ void ...@@ -172,6 +186,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -188,6 +203,7 @@ __global__ void ...@@ -188,6 +203,7 @@ __global__ void
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_d0_grid;
ignore = p_z_grid; ignore = p_z_grid;
ignore = p_b1_grid; ignore = p_b1_grid;
ignore = p_c_grid; ignore = p_c_grid;
...@@ -203,6 +219,7 @@ __global__ void ...@@ -203,6 +219,7 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
...@@ -278,6 +295,7 @@ template <index_t NumDimG, ...@@ -278,6 +295,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -546,9 +564,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -546,9 +564,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return lse_grid_desc_mraw; return lse_grid_desc_mraw;
} }
} }
// D in Gemm0 C position
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths_vec,
const std::vector<index_t>& d_gs_ms_ns_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths_vec, d_gs_ms_ns_strides_vec);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
...@@ -558,6 +583,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -558,6 +583,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeDGridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {})); using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -581,14 +607,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -581,14 +607,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch() {}
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
...@@ -606,6 +635,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -606,6 +635,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{ {
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...@@ -629,6 +663,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -629,6 +663,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
...@@ -639,6 +674,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -639,6 +674,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
D0DataType,
OutputDataType, OutputDataType,
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -654,6 +690,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -654,6 +690,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
KGridDesc_N_K, KGridDesc_N_K,
D0GridDesc_M_N,
ZGridDesc_M_N, ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
...@@ -690,6 +727,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -690,6 +727,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
BBlockLdsExtraN, BBlockLdsExtraN,
D0BlockTransferSrcScalarPerVector,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -740,6 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -740,6 +778,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::tuple<unsigned long long, unsigned long long> seeds) std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_biases},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -794,13 +833,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -794,13 +833,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
...@@ -820,6 +852,25 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -820,6 +852,25 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock( GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o_); y_grid_desc_m_o_);
} }
if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_desc_m_n = MakeDGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
}
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
...@@ -860,6 +911,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -860,6 +911,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// pointers // pointers
const InputDataType* p_a_grid_; const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_; const InputDataType* p_b_grid_;
const D0DataType* p_d0_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_; const InputDataType* p_c_grid_;
...@@ -872,6 +924,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -872,6 +924,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
...@@ -882,6 +935,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -882,6 +935,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// batch offsets // batch offsets
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
...@@ -946,6 +1000,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -946,6 +1000,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1< kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
InputDataType, InputDataType,
D0DataType,
OutputDataType, OutputDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -956,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -956,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
...@@ -976,6 +1032,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -976,6 +1032,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_d0_grid_,
arg.p_z_grid_, arg.p_z_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -991,6 +1048,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -991,6 +1048,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
......
...@@ -22,6 +22,7 @@ namespace ck { ...@@ -22,6 +22,7 @@ namespace ck {
template <typename InputDataType, template <typename InputDataType,
typename OutputDataType, typename OutputDataType,
typename D0DataType,
typename ZDataType, typename ZDataType,
typename GemmDataType, typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
...@@ -36,6 +37,7 @@ template <typename InputDataType, ...@@ -36,6 +37,7 @@ template <typename InputDataType,
typename QGridDesc_K0_M_K1, typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1, typename KGridDesc_K0_N_K1,
typename KGridDesc_N_K, typename KGridDesc_N_K,
typename D0GridDesc_M_N,
typename ZGridDesc_M_N, typename ZGridDesc_M_N,
typename VGridDesc_O0_N_O1, typename VGridDesc_O0_N_O1,
typename YGridDesc_M_O, typename YGridDesc_M_O,
...@@ -72,6 +74,7 @@ template <typename InputDataType, ...@@ -72,6 +74,7 @@ template <typename InputDataType,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN, index_t BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -1219,44 +1222,76 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1219,44 +1222,76 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return math::max(p_slash_sgrad_bytes_end, softmax_bytes_end, c_block_bytes_end); return math::max(p_slash_sgrad_bytes_end, softmax_bytes_end, c_block_bytes_end);
} }
// D0
static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
static constexpr auto D0M0 = Number<MPerBlock>{} / Number<MPerXdl>{};
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1, D0M2)),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));
return d0_grid_desc_m0_n0_m1_m2_n1_m3;
}
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout, bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
typename C0MatrixMask, typename C0MatrixMask,
typename YGradGridDesc_O0_M_O1> typename YGradGridDesc_O0_M_O1>
__device__ static void Run(const InputDataType* __restrict__ p_q_grid, __device__ static void
const InputDataType* __restrict__ p_k_grid, Run(const InputDataType* __restrict__ p_q_grid,
ZDataType* __restrict__ p_z_grid, const InputDataType* __restrict__ p_k_grid,
const InputDataType* __restrict__ p_v_grid, const D0DataType* __restrict__ p_d0_grid,
const InputDataType* __restrict__ p_y_grid, ZDataType* __restrict__ p_z_grid,
const FloatLSE* __restrict__ p_lse_grid, const InputDataType* __restrict__ p_v_grid,
const InputDataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_y_grid,
OutputDataType* __restrict__ p_qgrad_grid, const FloatLSE* __restrict__ p_lse_grid,
OutputDataType* __restrict__ p_kgrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
void* __restrict__ p_shared, OutputDataType* __restrict__ p_kgrad_grid,
const AElementwiseOperation& a_element_op, OutputDataType* __restrict__ p_vgrad_grid,
const BElementwiseOperation& b_element_op, void* __restrict__ p_shared,
const SElementwiseOperation& s_element_op, const AElementwiseOperation& a_element_op,
const B1ElementwiseOperation& b1_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const SElementwiseOperation& s_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1, const B1ElementwiseOperation& b1_element_op,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1, const CElementwiseOperation& c_element_op,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1, const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
y_grid_desc_mblock_mperblock_oblock_operblock, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const LSEGridDesc_M& lse_grid_desc_m, const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1, const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
const Block2CTileMap& block_2_ctile_map, y_grid_desc_mblock_mperblock_oblock_operblock,
const C0MatrixMask& c0_matrix_mask, const LSEGridDesc_M& lse_grid_desc_m,
const float p_drop, const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
ck::philox& ph, const Block2CTileMap& block_2_ctile_map,
const index_t z_random_matrix_offset, const C0MatrixMask& c0_matrix_mask,
const index_t raw_n_padded, const float p_drop,
const index_t block_idx_n) ck::philox& ph,
const index_t z_random_matrix_offset,
const index_t raw_n_padded,
const index_t block_idx_n)
{ {
ignore = p_d0_grid;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout); const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = const ushort p_dropout_in_16bits =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment