Commit d6fd270e authored by ltqin

send z matrix to kernel

parent 4307a754
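
Editor's summary (not part of the original commit message): the diff below threads a new Z tensor, which appears to carry the per-element random values used for dropout (note the philox RNG and rp_dropout parameters already present in the gridwise kernel), from the host-side example through the device op and into the kernel. A rough outline of the new data path, using only names that appear in this diff:

// Host example:
//   Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
//   DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
//   z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
//   gemm.MakeArgument(..., static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), ...,
//                     z_gs_ms_ns_lengths, z_gs_ms_ns_strides, ...);
// Device op (DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle):
//   Argument stores p_z_grid_, z_grid_desc_m_n_ and z_grid_desc_g_m_n_ (per-batch offsets via GetZBasePtr),
//   and the invoker launches kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2 with the extra Z arguments.
// Gridwise kernel (GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2):
//   Run(..., p_z_grid + z_batch_offset, ..., z_grid_desc_m_n, ...); // Run still ignores Z in this commit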
......@@ -49,6 +49,7 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
......@@ -61,6 +62,7 @@ using DataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = U16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
......@@ -92,6 +94,7 @@ using DeviceGemmInstance =
NumDimK,
NumDimO,
DataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
......@@ -330,6 +333,12 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
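// Editor's note: a tiny standalone helper (illustration only, not part of this commit) showing how the
// stride set above linearizes a logical (g0, g1, m, n) index. With input_permute the physical layout is
// [G0, M, G1, N], otherwise [G0, G1, M, N]; the strides are always ordered to match the logical dims
// {G0, G1, M, N}, so one formula covers both cases.
#include <cstdint>
#include <vector>

std::int64_t z_linear_offset(const std::vector<int>& strides, int g0, int g1, int m, int n)
{
    return std::int64_t(g0) * strides[0] + std::int64_t(g1) * strides[1] +
           std::int64_t(m) * strides[2] + std::int64_t(n) * strides[3];
}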
// The softmax statistic log-sum-exp (LSE) is used to speed up the softmax calculation in the backward pass
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...)) = exp(Si - LSE)
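// Editor's note: a minimal standalone sketch (illustration only; the function name is hypothetical) of the
// identity above: once LSE = log(exp(S0) + exp(S1) + ...) has been stored for a row, every probability can
// be recomputed as exp(Si - LSE) without re-reducing the row.
#include <cmath>
#include <vector>

std::vector<float> softmax_from_lse(const std::vector<float>& s, float lse)
{
    std::vector<float> p(s.size());
    for(std::size_t i = 0; i < s.size(); ++i)
        p[i] = std::exp(s[i] - lse); // P_i = exp(S_i) / exp(LSE)
    return p;
}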
......@@ -341,6 +350,7 @@ int run(int argc, char* argv[])
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
......@@ -348,10 +358,12 @@ int run(int argc, char* argv[])
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ks: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<ZDataType>{-1});
switch(init_method)
{
case 0: break;
......@@ -417,6 +429,7 @@ int run(int argc, char* argv[])
// calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N});
......@@ -427,6 +440,8 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach(
......@@ -442,6 +457,7 @@ int run(int argc, char* argv[])
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
......@@ -452,6 +468,7 @@ int run(int argc, char* argv[])
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
......@@ -464,6 +481,7 @@ int run(int argc, char* argv[])
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
......@@ -477,6 +495,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
......
......@@ -29,6 +29,7 @@ namespace device {
template <typename GridwiseGemm,
typename DataType,
typename ZDataType,
typename LSEDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
......@@ -37,6 +38,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename ZGridDesc_M_N,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
......@@ -50,9 +52,10 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid,
......@@ -67,6 +70,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDesc_M_N z_grid_desc_m_n,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -95,6 +99,8 @@ __global__ void
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
......@@ -107,6 +113,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_z_grid + z_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
......@@ -122,6 +129,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
......@@ -163,6 +171,7 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
......@@ -441,6 +450,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
const std::vector<index_t>& z_gs_ms_ns_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
}
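// Editor's note (illustration only, hypothetical names): Z occupies the Gemm0 C position because it has the
// same M x N extent per (G0, G1) batch as S = Q * K^T, so the existing Transform::MakeCGridDescriptor_M_N
// can be reused unchanged. A trivial standalone shape check:
#include <cassert>

void check_z_extent(int M, int N, int z_m, int z_n)
{
    assert(z_m == M); // rows of S = Q * K^T
    assert(z_n == N); // columns of S = Q * K^T
}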
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
......@@ -501,9 +516,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
......@@ -522,11 +539,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
......@@ -543,6 +562,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
......@@ -561,8 +585,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
};
......@@ -580,6 +606,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1,
YGridDesc_M_O,
LSEGridDesc_M,
......@@ -636,6 +663,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
Argument(
const DataType* p_a_grid,
const DataType* p_b_grid,
ZDataType* p_z_grid,
const DataType* p_b1_grid,
const DataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid,
......@@ -649,6 +677,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
......@@ -669,6 +699,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
p_lse_grid_{p_lse_grid},
......@@ -680,6 +711,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
......@@ -697,6 +729,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
a_element_op_{a_element_op},
......@@ -721,6 +755,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
......@@ -752,8 +787,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1));
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
// Print();
}
......@@ -785,6 +819,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
// pointers
const DataType* p_a_grid_;
const DataType* p_b_grid_;
ZDataType* p_z_grid_;
const DataType* p_b1_grid_;
const DataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
......@@ -796,6 +831,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_;
......@@ -807,6 +843,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
......@@ -865,9 +902,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
GridwiseGemm,
DataType,
ZDataType,
LSEDataType,
AElementwiseOperation,
BElementwiseOperation,
......@@ -876,6 +914,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::ZGridDesc_M_N,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
......@@ -893,6 +932,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_z_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_lse_grid_,
......@@ -907,6 +947,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.z_grid_desc_m_n_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
......@@ -1031,6 +1072,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
static auto MakeArgument(
const DataType* p_a,
const DataType* p_b,
ZDataType* p_z,
const DataType* p_b1,
const DataType* p_c,
const LSEDataType* p_lse,
......@@ -1044,6 +1086,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
......@@ -1065,6 +1109,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
{
return Argument{p_a,
p_b,
p_z,
p_b1,
p_c,
p_lse,
......@@ -1078,6 +1123,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
......@@ -1103,6 +1150,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_z,
const void* p_b1,
const void* p_c,
const void* p_lse,
......@@ -1116,6 +1164,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
......@@ -1137,6 +1187,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
{
return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
static_cast<const DataType*>(p_b),
static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1),
static_cast<const DataType*>(p_c),
static_cast<const LSEDataType*>(p_lse),
......@@ -1150,6 +1201,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
......
......@@ -32,6 +32,7 @@ template <typename DataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1,
typename ZGridDesc_M_N,
typename VGridDesc_N0_O_N1,
typename CGridDesc_M_N,
typename LSEGridDesc_M,
......@@ -118,14 +119,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(M, N)),
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
......@@ -390,8 +394,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using CGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 =
remove_cvref_t<decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(0, 0))>;
using CGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
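// Editor's note (summary only, derived from the mfma selector constants above and the comments on the
// threadwise Z descriptor later in this file): the ten dimensions of this descriptor decompose the
// M x N domain of Z as
//   M0 = MBlock,  N0 = NBlock      (block tile indices)
//   M1 = MRepeat, N1 = NRepeat     (MXdlPerWave / NXdlPerWave repeats per wave)
//   M2 = MWave,   N2 = NWave       (wave position inside the block)
//   M3 = MPerXdl                   (row inside one xdlops tile)
//   N3 = num_groups_per_blk, N4 = num_input_blks, N5 = group_size (per-lane registers)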
// S / dP Gemm (type 1 rcr)
struct Gemm0
......@@ -1121,6 +1125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
......@@ -1136,6 +1141,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDesc_M_N& z_grid_desc_m_n,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
......@@ -1149,6 +1155,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
FloatGemmAcc rp_dropout,
ck::philox& ph)
{
ignore = p_z_grid;
ignore = z_grid_desc_m_n;
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -1418,6 +1426,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// z vgpr copy to global
//
// z matrix threadwise desc
if(get_thread_global_1d_id() == 0)
{
printf("m0: %d n0: %d m1: %d n1: %d m2: %d n2: %d n3: %d n4: %d \n",
m0.value, // MRepeat
n0.value, // NRepeat
m1.value, // MWaveId
n1.value, // NWaveId
m2.value, // MPerXdl
n2.value, // NGroupNum
n3.value, // NInputNum
n4.value); // registerNum
}
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
......@@ -1430,14 +1450,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n3, // NInputNum
n4)); // registerNum
// z matrix global desc
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n);
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
if(get_thread_global_1d_id() == 191)
{
printf("wave_id{ %d, %d, %d}, wave_m_n_id{%d, %d}\n",
wave_id[I0],
wave_id[I1],
wave_id[I2],
wave_m_n_id[I0],
wave_m_n_id[I1]);
}
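// Editor's note (illustration only; GetGemm0WaveIdx and GetGemm0WaveMNIdx are presumably defined elsewhere
// in this gridwise struct and are not touched by this diff): the two debug prints above inspect the
// thread-to-tile mapping used for the Z register-to-global copy,
//   wave_id     = (M-wave, N-wave, lane), with the lane index in 0..63 as noted above,
//   wave_m_n_id = the lane's (m, n) position inside one xdlops tile.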
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
......