Commit f752739c authored by danyao12

Merge branch 'mha-train-develop' into mha-train-ldsbypass

parents b3a96764 26fa4782
......@@ -22,7 +22,6 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#include "ck/library/utility/numeric.hpp"
namespace ck {
namespace tensor_operation {
......
......@@ -10,7 +10,6 @@
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
......@@ -21,8 +20,6 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
......
......@@ -10,7 +10,6 @@
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
......@@ -21,8 +20,6 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
......
......@@ -22,8 +22,6 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
......
......@@ -10,7 +10,6 @@
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
......@@ -22,8 +21,6 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
......
......@@ -21,8 +21,6 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
......
......@@ -21,8 +21,6 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
......
......@@ -24,6 +24,7 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename D0DataType,
typename GemmAccDataType,
typename GroupKernelArg,
typename AElementwiseOperation,
......@@ -100,6 +101,17 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
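// D0 is the optional attention bias applied to the Gemm0 (S = Q*K^T) output.
// When D0DataType is void, the constexpr branch below is compiled out and the
// bias pointer stays null, so the unbiased path carries no extra cost.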
const D0DataType* tmp_p_d0_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
}
if constexpr(Deterministic)
{
for(index_t i = 0; i < num_blocks_per_batch; i++)
......@@ -107,6 +119,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
tmp_p_d0_grid,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr
......@@ -124,9 +137,10 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
......@@ -144,6 +158,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
tmp_p_d0_grid,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
......@@ -160,9 +175,10 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
......@@ -233,6 +249,7 @@ template <index_t NumDimG,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t DropoutStep,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -247,6 +264,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t Acc0BiasTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
......@@ -258,6 +276,7 @@ template <index_t NumDimG,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t Acc1BiasTransferSrcScalarPerVector,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
......@@ -285,11 +304,10 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
using D0DataType = Acc0BiasDataType;
using D1DataType = Acc1BiasDataType;
static_assert(std::is_void<Acc1BiasDataType>::value, "Acc1 Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
......@@ -392,8 +410,27 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
}
}
static auto
MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
static auto
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
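// Note: the D0 bias shares the [G, M, N] layout of the S = Q*K^T intermediate,
// so the existing C descriptor transforms are reused for it verbatim.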
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -401,6 +438,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
......@@ -426,12 +464,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
......@@ -449,6 +489,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
......@@ -472,6 +517,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
......@@ -482,6 +528,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
ADataType, // TODO: distinguish A/B datatype
Acc0BiasDataType,
ZDataType,
GemmDataType,
GemmAccDataType,
......@@ -496,6 +543,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
D0GridDesc_M_N,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
ZGridDesc_M_N,
......@@ -515,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
DropoutStep,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......@@ -531,6 +580,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
Acc0BiasTransferSrcScalarPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
......@@ -543,6 +593,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Acc1BiasTransferSrcScalarPerVector,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec != MaskingSpecialization::MaskDisabled,
......@@ -555,6 +606,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const D0DataType* p_d0_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
ZDataType* p_z_grid_;
......@@ -563,11 +615,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_;
......@@ -600,6 +654,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_;
// raw D0 N length and stride (for the vector-access check)
std::vector<ck::index_t> d0_n_length_stride_;
};
// Argument
......@@ -612,9 +669,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<const void*> p_acc0_biases_vec,
std::vector<const void*> p_acc1_biases_vec,
std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -628,28 +685,28 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_element_op_{b1_element_op},
c_element_op_{c_element_op}
{
group_count_ = problem_desc_vec.size();
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size() &&
(group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0)))
{
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
}
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0;
index_t z_random_matrix_offset = 0;
for(std::size_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_d0_grid = p_acc0_biases_vec.size() > 0
? static_cast<const D0DataType*>(p_acc0_biases_vec[i])
: nullptr;
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
......@@ -660,12 +717,31 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
is_lse_storing_ = false;
}
const auto& problem_desc = problem_desc_vec[i];
const auto& problem_desc = problem_desc_vec[i];
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
}
else
{
tmp_d0_gs_ms_ns_lengths = {1, 1, 1, 1};
tmp_d0_gs_ms_ns_strides = {0, 0, 0, 0};
}
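// With D0DataType == void, the unit lengths and zero strides above produce a
// degenerate placeholder descriptor; the bias pointer is never dereferenced in
// that case, so this only keeps the descriptor types well-formed.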
const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
const auto d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n);
const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
......@@ -679,6 +755,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
......@@ -690,12 +768,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n);
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
z_grid_desc_m_n);
const index_t BlockStart = grid_size_;
......@@ -710,6 +784,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k,
b_grid_desc_g_n_k,
d0_grid_desc_g_m_n,
b1_grid_desc_g_n_k,
c_grid_desc_g_m_n,
z_grid_desc_g_m_n,
......@@ -721,18 +796,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
grid_size_ += grid_size_grp;
const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
......@@ -740,15 +803,17 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
group_kernel_args_.push_back({p_a_grid,
p_b_grid,
p_d0_grid,
p_b1_grid,
p_c_grid,
p_z_grid,
p_lse_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_desc_m_n,
lse_grid_desc_m,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
......@@ -764,6 +829,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
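// Each group reserves a disjoint Philox-offset range of raw_m_padded *
// raw_n_padded * batch_count ids so dropout positions cannot collide across
// groups; e.g. raw_m_padded = 128, raw_n_padded = 256, batch_count = 4
// advances the offset by 131072.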
// record D0's raw N length and stride for the vectorized-access check
std::vector<ck::index_t> d0_n_length_stride;
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);
group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
......@@ -777,7 +847,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n,
d0_n_length_stride});
}
is_dropout_ = p_dropout > 0.0;
......@@ -846,6 +917,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
D0DataType,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
......@@ -997,6 +1069,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
return false;
}
if constexpr(!is_same<D0DataType, void>::value)
{
if(device_arg.d0_n_length_stride_[1] == 1 &&
device_arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
if(device_arg.d0_n_length_stride_[1] != 1 &&
Acc0BiasTransferSrcScalarPerVector != 1)
{
return false;
}
}
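// Illustration of the check above: with Acc0BiasTransferSrcScalarPerVector == 4
// and a unit N stride, D0's raw N length must be a multiple of 4 (70 is
// rejected, 72 passes); with a non-unit N stride, vectorized loads are
// impossible and only ScalarPerVector == 1 is accepted.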
// Check if having main loop
const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -1077,9 +1163,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<const void*> p_acc0_biases_vec,
std::vector<const void*> p_acc1_biases_vec,
std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......@@ -1116,9 +1202,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<const void*> p_acc0_biases_vec,
std::vector<const void*> p_acc1_biases_vec,
std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
......
......@@ -1533,8 +1533,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
......@@ -1966,16 +1966,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
......@@ -1473,8 +1473,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
......@@ -1865,16 +1865,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
......@@ -110,6 +110,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
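// A DropoutTile therefore spans num_input_blks * 8 = 16 elements, i.e. one
// 8x16 Philox batch, so dropout randomness is organized in DropoutTile x
// DropoutTile (16x16) tiles of the Z matrix.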
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// C desc for source in blockwise copy
......@@ -119,10 +124,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
......@@ -136,9 +140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
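// GetPaddedSize rounds a raw extent up to a multiple of DropoutTile so the Z /
// RNG matrix tiles evenly; e.g. with DropoutTile == 16, a raw length of 100
// pads to 112.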
__device__ static auto GetGemm0WaveIdx()
......@@ -542,9 +544,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
......@@ -646,8 +646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 would
// cause a mismatch in the summation index, e.g. c[0:7] = a1[[0:3, 8:11]] * b1[0:7],
// so we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = mfma.group_size;
static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4
static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1
......@@ -770,9 +769,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
......@@ -1570,8 +1568,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
......@@ -1759,7 +1757,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if constexpr(IsDropout)
{
......@@ -1774,23 +1771,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
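// RNG id layout used above: each DropoutTile x DropoutTile tile owns a
// contiguous id block starting at global_tile_id; within the tile, the wave's
// input-block index selects an M4-sized row group and the column advances by
// raw_n_padded. This should let forward and backward kernels that share the
// same tiling regenerate identical dropout masks.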
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
......@@ -1806,15 +1807,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
......
......@@ -121,6 +121,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
......@@ -133,10 +138,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
......@@ -150,9 +154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -522,9 +524,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
......@@ -657,8 +657,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 would
// cause a mismatch in the summation index, e.g. c[0:7] = a1[[0:3, 8:11]] * b1[0:7],
// so we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
......@@ -709,9 +708,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;
static constexpr index_t GemmKPack = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
using BThreadClusterLengths =
......@@ -1554,8 +1551,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
......@@ -1722,7 +1719,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if constexpr(IsDropout)
{
......@@ -1737,23 +1733,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
......@@ -1769,14 +1769,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
......
......@@ -109,6 +109,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// C desc for source in blockwise copy
......@@ -118,10 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
......@@ -135,9 +139,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -563,9 +565,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
......@@ -667,8 +667,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 would
// cause a mismatch in the summation index, e.g. c[0:7] = a1[[0:3, 8:11]] * b1[0:7],
// so we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = mfma.group_size;
static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4
static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1
......@@ -791,9 +790,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
......@@ -1621,8 +1619,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
......@@ -1946,7 +1944,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if constexpr(IsDropout)
{
......@@ -1961,23 +1958,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
......@@ -1993,15 +1994,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
......
......@@ -120,6 +120,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
......@@ -132,10 +137,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
......@@ -149,9 +153,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -543,9 +545,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
......@@ -678,8 +678,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 would
// cause a mismatch in the summation index, e.g. c[0:7] = a1[[0:3, 8:11]] * b1[0:7],
// so we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
......@@ -730,9 +729,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;
static constexpr index_t GemmKPack = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
using BThreadClusterLengths =
......@@ -1582,8 +1579,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
......@@ -1862,7 +1859,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if constexpr(IsDropout)
{
......@@ -1877,23 +1873,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
......@@ -1909,14 +1909,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
......
......@@ -873,8 +873,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
// z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -1022,16 +1022,16 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tensor_buffer),
false,
decltype(n0),
decltype(i)>(
acc_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
......@@ -25,6 +25,7 @@ namespace ck {
*
*/
template <typename FloatAB,
typename D0DataType,
typename ZDataType,
typename FloatGemm,
typename FloatGemmAcc,
......@@ -39,6 +40,7 @@ template <typename FloatAB,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename D0GridDesc_M_N,
typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename ZGridDesc_M_N,
......@@ -58,6 +60,7 @@ template <typename FloatAB,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t DropoutStepValue,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -74,6 +77,7 @@ template <typename FloatAB,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
......@@ -86,6 +90,7 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t D1BlockTransferSrcScalarPerVector,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
......@@ -93,6 +98,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -104,6 +113,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto WaveSize = 64;
......@@ -121,54 +132,76 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
static constexpr auto DropoutMThread = DropoutTile; // 16
static constexpr auto DropoutTilePerXdl = NPerXdl / DropoutTile; // 2
static constexpr auto DropoutStep = Number<DropoutStepValue>{}; // 1 2 4
static constexpr auto DropoutNRepeat =
Number<math::integer_divide_ceil(DropoutStep, DropoutTilePerXdl)>{}; // 1 1 2
static constexpr auto DropoutGroupPerTile =
Number<mfma.num_groups_per_blk / DropoutTilePerXdl>{}; // 2
static constexpr auto DropoutStepPerXdl =
Number<math::min(DropoutStep, DropoutTilePerXdl)>{}; // 1 2 2
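// Worked values, assuming NPerXdl == 32 so that DropoutTilePerXdl == 2 (per
// the inline notes): DropoutStep 1 -> DropoutNRepeat 1, DropoutStepPerXdl 1;
// DropoutStep 2 -> 1, 2; DropoutStep 4 -> 2, 2.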
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in gridwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
const auto M0 = M / MPerBlock;
const auto N0 = N / (DropoutNRepeat * NPerXdl);
constexpr auto M1 = MXdlPerWave;
constexpr auto N1 = DropoutNRepeat;
constexpr auto M2 = Gemm0MWaves;
constexpr auto N2 = Gemm0NWaves;
constexpr auto M3 = DropoutTilePerXdl;
constexpr auto N3 = DropoutStepPerXdl;
constexpr auto M4 = DropoutTile;
constexpr auto N4 = DropoutGroupPerTile;
constexpr auto N5 = mfma.num_input_blks;
constexpr auto N6 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4, N5, N6))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 8>{}, Sequence<1, 3, 5, 7, 9, 10, 11>{}));
}
__host__ __device__ static constexpr auto
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5()
{
constexpr auto M0 = MXdlPerWave;
constexpr auto N0 = DropoutNRepeat;
constexpr auto M1 = Gemm0MWaves;
constexpr auto N1 = Gemm0NWaves;
constexpr auto M2 = DropoutTilePerXdl;
constexpr auto N2 = DropoutStepPerXdl;
constexpr auto M3 = DropoutTile;
constexpr auto N3 = DropoutGroupPerTile;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(M0, N0, M1, N1, M2, N2, M3, N3, N4, N5));
return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
}
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
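// e.g. with DropoutTile = 16: GetPaddedSize(70) = ceil(70 / 16) * 16 = 80,
// i.e. sizes are rounded up to a whole number of dropout tiles.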
__device__ static auto GetGemm0WaveIdx()
......@@ -407,14 +440,36 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_grid_desc_m_n);
}
// D0 desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
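// Note: D0 (the optional bias) reuses the 10-d (M, N) unmerge of the attention
// score tile, so d0_threadwise_copy below can fetch its bias slice with the
// same wave / lane indexing already used for acc_thread_buf.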
using D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0GridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(ZGridDesc_M_N{}))>;
struct SharedMemTrait
{
......@@ -452,10 +507,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
// LDS allocation for Z shuffle in LDS
static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
    GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
static constexpr auto z_shuffle_block_space_size =
    z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize();
};
template <bool HasMainKBlockLoop,
......@@ -465,6 +520,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const D0DataType* __restrict__ p_d0_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
ZDataType* __restrict__ p_z_grid,
......@@ -477,11 +533,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
......@@ -603,9 +661,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......@@ -765,8 +821,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause a mismatch in the summation index, e.g. c[0:7] = a1[0:3, 8:11] * b1[0:7];
// therefore we may just as well assign Gemm1KPack = group_size.
constexpr index_t Gemm1KPack = mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......@@ -891,37 +946,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// bias (d matrix)
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -929,29 +961,106 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
n3, // NInputNum
n4)); // RegisterNum
auto d0_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<D0DataType,
D0DataType,
decltype(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
D0BlockTransferSrcScalarPerVector,
1,
false>(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
DropoutNRepeat, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n3,
n4)); // RegisterNum
constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
DropoutNRepeat, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
DropoutGroupPerTile,
n3,
n4, // RegisterNum
m2));
// z is random number matrix for dropout verify
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId
m0, // MRepeat
DropoutNRepeat, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n3,
n4)); // RegisterNum
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
    GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5();
constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I0); // 1
constexpr auto ZN0 =
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I1); // 1 1 2
constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I2); // 4
constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I3); // 1
constexpr auto ZM2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I4); // 2
constexpr auto ZN2 =
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I5); // 1 2 2
constexpr auto ZM3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I6); // 16
constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I7); // 2
constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I8); // 2
constexpr auto ZN5 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetLength(I9); // 4
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
transform_tensor_descriptor(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(make_pass_through_transform(ZM0),
make_pass_through_transform(ZN0),
make_pass_through_transform(ZM1),
make_pass_through_transform(ZN1),
make_pass_through_transform(ZM2),
make_pass_through_transform(ZN2),
make_unmerge_transform(make_tuple(ZM3 / ZN4 / ZN5, ZN4, ZN5)),
make_merge_transform_v3_division_mod(make_tuple(ZN3, ZN4, ZN5))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -959,115 +1068,130 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7, 8, 9>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6, 7, 8>{},
Sequence<9>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6.GetElementSpaceSize());
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ushort*>(p_shared),
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
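// The Z shuffle is an LDS round trip: random values produced in the mfma
// accumulator layout are written vgpr -> lds below, then read back lds -> vgpr
// through a differently ordered descriptor, which regroups them into the
// dropout-tile layout consumed by ApplyDropoutWithZ.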
auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<m0, // MRepeat
DropoutNRepeat, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n3,
n4>, // RegisterNum
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1] / DropoutMThread,
0,
wave_m_n_id[I1] % DropoutMThread,
0,
wave_m_n_id[I0],
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ushort,
ushort,
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
Sequence<m0,
DropoutNRepeat,
m1,
n1,
I1,
DropoutStepPerXdl,
DropoutGroupPerTile,
n3,
n4,
m2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
1,
1,
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1] / DropoutMThread,
0,
0,
wave_m_n_id[I0],
0,
wave_m_n_id[I1] % DropoutMThread)};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
DropoutNRepeat, // NRepeat
m1, // MWaveId
n1, // NWaveId
I1,
DropoutStepPerXdl,
m2,
DropoutGroupPerTile,
n3,
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>,
11, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1] / DropoutMThread,
0,
wave_m_n_id[I1] % DropoutMThread,
0,
wave_m_n_id[I0],
0),
tensor_operation::element_wise::PassThrough{}};
......@@ -1163,14 +1287,41 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// get register
StaticBuffer<AddressSpaceEnum::Vgpr,
D0DataType,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
d0_thread_buf;
// load data from global
d0_threadwise_copy.Run(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0_grid_buf,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0_thread_buf);
// acc add bias
static_for<0, m0 * n0 * n2 * n4, 1>{}(
[&](auto i) { acc_thread_buf(i) += d0_thread_buf[i]; });
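// advance the D0 window by one N block so the next outer-loop tile
// accumulates the matching bias columns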
d0_threadwise_copy.MoveSrcSliceWindow(
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
blockwise_softmax.Run(acc_thread_buf, workspace_buf);
constexpr auto iterator_offset = Number<8 * DropoutStep>{};
constexpr auto iterator_step = Number<n0 * n1 * n2 * n3 * n4 / 8 / DropoutStep>{};
if constexpr(IsDropout) // dropout
{
......@@ -1187,49 +1338,44 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
n_global; // unique element global 1d id
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer),
decltype(n0),
decltype(iterator_step),
decltype(DropoutTile)>(
ph, global_elem_id, z_tensor_buffer);
z_tmp_thread_copy_vgpr_to_lds.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_block_buf);
z_shuffle_thread_copy_lds_to_vgpr.Run(
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_block_buf,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tensor_buffer),
decltype(n0),
decltype(iterator_step),
decltype(i)>(acc_thread_buf,
z_tensor_buffer);
// save z to global
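// Z does not depend on the gemm1 N block, so only the first N-block
// iteration stores it; each random value is then written exactly once.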
if(p_z_grid && (gemm1_n_block_data_idx_on_grid == 0))
{
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
}
});
}
// TODO: may convert to log domain
......@@ -1350,7 +1496,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
static_for<0, MXdlPerWave, 1>{}(
[&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
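// LSE is likewise independent of the gemm1 N block: only lanes holding
// valid rows, and only the first N-block iteration, write it out.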
if((get_lane_local_1d_id() < AccM2) && (gemm1_n_block_data_idx_on_grid == 0))
{
static_for<0, MXdlPerWave, 1>{}([&](auto I) {
// copy from VGPR to Global
......
......@@ -10,6 +10,7 @@
#include <type_traits>
#include <utility>
#include "ck/utility/type_convert.hpp"
#include "ck/utility/data_type.hpp"
namespace ck {
......