Commit f82a220f authored by guangzlu

v4 pass

parent ff88ffa4
@@ -766,7 +766,7 @@ int run(int argc, char* argv[])
         auto argument = gemm.MakeArgument(
             static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
             static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
-            static_cast<ZDataType*>(nullptr), // set to nullptr
+            static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), // z buffer so the dropout mask is written out
             static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
             static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
             static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
...
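[Editor's note] The z_device_buf passed above is not allocated in this hunk; a minimal sketch of how it would plausibly be set up, mirroring the commented-out allocation that appears later in this commit and assuming the CK utility headers (host_tensor.hpp, device_memory.hpp) are included. The names z_gs_ms_ns_lengths/z_gs_ms_ns_strides are taken from the commented code, not shown being defined here:

// Hypothetical allocation of the dropout-mask (z) buffer, patterned on the
// q/k/v buffer setup in this example. Assumption: ZDataType is a 16-bit
// unsigned type, as suggested by the ushort buffers in the kernel below.
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); // host tensor
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data()); // upload (e.g. zero-initialized) data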
@@ -32,7 +32,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 64 // DIM should be a multiple of 8.
+#define DIM 32 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
@@ -710,17 +710,17 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M = 1000; // 512
-    ck::index_t N = 1000; // 512
+    ck::index_t M = 500; // 512
+    ck::index_t N = 500; // 512
     ck::index_t K = DIM;
     ck::index_t O = DIM;
-    ck::index_t G0 = 4; // 54
-    ck::index_t G1 = 6; // 16
+    ck::index_t G0 = 2; // 54
+    ck::index_t G1 = 1; // 16
     bool input_permute = false;
     bool output_permute = false;
-    float p_drop = 0.0;
+    float p_drop = 0.1;
     const unsigned long long seed = 1;
     const unsigned long long offset = 0;
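[Editor's note] Turning p_drop on (0.0 -> 0.1) is what exercises the new z path. For orientation, BlockwiseDropout further down in this commit keeps an element iff its 16-bit Philox random is <= a threshold and rescales survivors; a hedged host-side sketch of that logic, with the threshold formula inferred from the kernel rather than quoted from it:

// Sketch of the dropout test used by BlockwiseDropout: random ushort r is kept
// iff r <= p_dropout_16bits, and kept values are scaled by 1/(1 - p_drop) so
// the expectation is unchanged. The exact threshold rounding is an assumption.
#include <cstdint>

float apply_dropout(float x, uint16_t r, float p_drop)
{
    const uint16_t p_dropout_16bits = static_cast<uint16_t>((1.0f - p_drop) * 65535.0f);
    const float scale = 1.0f / (1.0f - p_drop); // rescale survivors
    return (r <= p_dropout_16bits) ? x * scale : 0.0f;
}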
@@ -944,7 +944,7 @@ int run(int argc, char* argv[])
         static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
         static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
         static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
-        static_cast<ZDataType*>(nullptr),
+        static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
         static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
@@ -998,7 +998,7 @@ int run(int argc, char* argv[])
     auto argument_bwd = gemm_bwd.MakeArgument(
         static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
         static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
-        static_cast<ZDataType*>(nullptr), // set to nullptr
+        static_cast<ZDataType*>(z_bwd_device_buf.GetDeviceBuffer()), // z buffer for the backward dropout mask
         static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
         static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
         static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
@@ -1399,20 +1399,20 @@ int run(int argc, char* argv[])
         pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
                                      qgrad_gs_ms_ks_host_result.mData,
                                      "error",
-                                     1e-2,
-                                     1e-2);
+                                     1e-3,
+                                     1e-3);
         std::cout << "Checking kgrad:\n";
         pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
                                      kgrad_gs_ns_ks_host_result.mData,
                                      "error",
-                                     1e-2,
-                                     1e-2);
+                                     1e-3,
+                                     1e-3);
         std::cout << "Checking vgrad:\n";
         pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
                                      vgrad_gs_os_ns_host_result.mData,
                                      "error",
-                                     1e-2,
-                                     1e-2);
+                                     1e-3,
+                                     1e-3);
     }
     return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
...
@@ -145,14 +145,14 @@ struct BlockwiseDropout
             ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
         }
-        ushort tmp_id[tmp_size];
-        for(int i = 0; i < philox_calls; i++)
-        {
-            for(int j = 0; j < 4; j++)
-            {
-                tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
-            }
-        }
+        // ushort tmp_id[tmp_size];
+        // for(int i = 0; i < philox_calls; i++)
+        //{
+        //    for(int j = 0; j < 4; j++)
+        //    {
+        //        tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
+        //    }
+        //}
         block_sync_lds();
@@ -162,7 +162,7 @@ struct BlockwiseDropout
             auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
             in_thread_buf(offset) =
                 execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
-            z_thread_buf(offset) = tmp_id[tmp_index];
+            z_thread_buf(offset) = tmp[tmp_index];
             tmp_index = tmp_index + 1;
         });
     });
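[Editor's note] The functional change above: z now records the Philox random value itself rather than a debug element id, so a consumer of z can reproduce exactly which elements were kept by re-applying the same r <= threshold test. A hedged sketch of that replay (plain C++, not CK API; names are illustrative):

// Replaying a saved dropout mask: given stored 16-bit randoms z[i], reproduce
// the forward keep/drop decisions without regenerating Philox state.
#include <cstdint>

void replay_dropout(float* buf, const uint16_t* z, int n,
                    uint16_t p_dropout_16bits, float scale)
{
    for(int i = 0; i < n; ++i)
        buf[i] = (z[i] <= p_dropout_16bits) ? buf[i] * scale : 0.0f; // same test as forward
}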
@@ -208,17 +208,17 @@ struct BlockwiseDropout
         ushort tmp[tmp_size];
         for(int i = 0; i < philox_calls; i++)
         {
-            ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
+            ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
         }
-        ushort tmp_id[tmp_size];
-        for(int i = 0; i < philox_calls; i++)
-        {
-            for(int j = 0; j < 4; j++)
-            {
-                tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
-            }
-        }
+        // ushort tmp_id[tmp_size];
+        // for(int i = 0; i < philox_calls; i++)
+        //{
+        //    for(int j = 0; j < 4; j++)
+        //    {
+        //        tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
+        //    }
+        //}
         block_sync_lds();
@@ -226,7 +226,7 @@ struct BlockwiseDropout
     static_for<0, MRepeat, 1>{}([&](auto iM) {
         static_for<0, KRepeat, 1>{}([&](auto iK) {
             auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-            z_thread_buf(offset) = tmp_id[tmp_index];
+            z_thread_buf(offset) = tmp[tmp_index];
             tmp_index = tmp_index + 1;
         });
     });
...
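[Editor's note] On the added * MRaw in this overload: each get_random_4x16 call fills the randoms for 8 consecutive element ids, and this variant appears to walk elements with a stride of MRaw in the flattened index space, so successive calls must advance the id by 8 * MRaw to stay aligned with the elements they cover. A tiny sketch of the two offset patterns (this reading is an assumption, not quoted from CK docs):

// Philox id offsets used by the two BlockwiseDropout variants in this file.
int philox_offset(int base, int call_idx, bool strided, int MRaw)
{
    // contiguous: ids base, base+8, base+16, ...
    // strided:    ids base, base+8*MRaw, base+16*MRaw, ...
    return strided ? base + call_idx * 8 * MRaw : base + call_idx * 8;
}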
@@ -22,6 +22,7 @@
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/device_memory.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -40,6 +41,7 @@ template <typename GridwiseGemm,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
+          typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5,
           typename B1GridDesc_BK0_N_BK1,
           typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
           typename LSEGridDescriptor_M,
@@ -73,6 +75,8 @@ __global__ void
     const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
     const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
         c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+    const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
+        c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
     const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
     const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
         c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -138,6 +142,7 @@ __global__ void
         a_grid_desc_ak0_m_ak1,
         b_grid_desc_bk0_n_bk1,
         c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+        c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
         b1_grid_desc_bk0_n_bk1,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         lse_grid_desc_m,
@@ -173,6 +178,7 @@ __global__ void
         a_grid_desc_ak0_m_ak1,
         b_grid_desc_bk0_n_bk1,
         c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+        c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
         b1_grid_desc_bk0_n_bk1,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         lse_grid_desc_m,
@@ -828,6 +834,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
                 GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
+            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_ =
+                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4(
+                    z_grid_desc_m_n_);
+
+            // tmp z tensor for shuffle
+            // Tensor<ZDataType> z_tmp_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
+            // DeviceMem z_tmp_device_buf(sizeof(ZDataType) *
+            //     z_tmp_gs_ms_ns.mDesc.GetElementSpaceSize());
+            // z_tmp_device_buf.ToDevice(z_tmp_gs_ms_ns.mData.data());
+            // p_z_tmp_grid_ = reinterpret_cast<ZDataType*>(z_tmp_device_buf.GetDeviceBuffer());
+
             // Print();
         }
@@ -859,7 +877,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         // pointers
         const InputDataType* p_a_grid_;
         const InputDataType* p_b_grid_;
+        // ZDataType* p_z_tmp_grid_;
         ZDataType* p_z_grid_;
         const InputDataType* p_b1_grid_;
         const InputDataType* p_c_grid_;
         const LSEDataType* p_lse_grid_;
@@ -890,6 +911,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
             c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
+        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
+            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_;
+
         // block-to-c-tile map
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -952,6 +976,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 DeviceOp::AGridDesc_AK0_M_AK1,
                 DeviceOp::BGridDesc_BK0_N_BK1,
                 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
+                typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5,
                 DeviceOp::B1GridDesc_BK0_N_BK1,
                 typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
                 DeviceOp::LSEGridDesc_M,
@@ -986,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 arg.a_grid_desc_ak0_m_ak1_,
                 arg.b_grid_desc_bk0_n_bk1_,
                 arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
+                arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_,
                 arg.b1_grid_desc_bk0_n_bk1_,
                 arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
                 arg.lse_grid_desc_m_,
...
@@ -122,6 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         constexpr auto M3 = mfma.num_groups_per_blk;
         constexpr auto M4 = mfma.num_input_blks;
         constexpr auto M5 = mfma.group_size;
+
         return transform_tensor_descriptor(
             z_grid_desc_m_n,
             make_tuple(make_unmerge_transform(
@@ -132,6 +133,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
     }
+    __host__ __device__ static constexpr auto
+    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4(const ZGridDesc_M_N& z_grid_desc_m_n)
+    {
+        const auto M = z_grid_desc_m_n.GetLength(I0);
+        const auto N = z_grid_desc_m_n.GetLength(I1);
+
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto M3 = mfma.num_groups_per_blk; // 4
+        constexpr auto M4 = mfma.num_input_blks;     // 2
+        constexpr auto M5 = mfma.group_size;         // 4
+
+        return transform_tensor_descriptor(
+            z_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(
+                           make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
+                       make_unmerge_transform(
+                           make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl / M5, M5))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2, 4, 6, 7, 9>{}, Sequence<1, 3, 5, 8, 10>{}));
+    }
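[Editor's note] An unmerge like the one in this new descriptor is only valid when the factor products reconstruct M and N exactly. A quick standalone sanity check of that invariant; the tile sizes below (128-wide blocks, 32-wide xdlops, 4 waves) are example values, not this kernel's actual configuration:

// Check that the six M-factors multiply back to M and the five N-factors to N,
// since MPerBlock decomposes as MXdlPerWave * Gemm0MWaves * (M3 * M4 * M5).
#include <cassert>

int main()
{
    const int M3 = 4, M4 = 2, M5 = 4; // mfma: groups/blk, input blks, group size
    const int MPerBlock = 128, MXdlPerWave = 4, Gemm0MWaves = 1;
    const int NPerBlock = 128, NXdlPerWave = 4, Gemm0NWaves = 1, NPerXdl = 32;
    const int M = 256, N = 256;

    assert((M / MPerBlock) * MXdlPerWave * Gemm0MWaves * M3 * M4 * M5 == M);
    assert((N / NPerBlock) * NXdlPerWave * Gemm0NWaves * (NPerXdl / M5) * M5 == N);
    return 0;
}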
     __device__ static auto GetGemm0WaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();
@@ -399,6 +420,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
         MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
+    using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5 = remove_cvref_t<decltype( // for shuffle
+        MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4(ZGridDesc_M_N{}))>;
+
     // Q / K / V / dY
     struct GemmBlockwiseCopy
     {
@@ -1235,7 +1259,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
         const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
         const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
-            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+            z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+        const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5&
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
         const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
         const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
             y_grid_desc_mblock_mperblock_oblock_operblock,
@@ -1581,136 +1607,148 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             // z vgpr copy to global
             //
             // z matrix threadwise desc
-            constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
+            constexpr auto z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
                 make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                                I1, // NBlockID
                                                                m0, // MRepeat
                                                                n0, // NRepeat
                                                                m1, // MWaveId
                                                                n1, // NWaveId
-                                                               m2, // MPerXdl
-                                                               m3, // NGroupNum
-                                                               m4, // NInputNum
-                                                               n2)); // registerNum
+                                                               m2, // MGroupNum
+                                                               m3, // MInputNum
+                                                               m4, // registerNum
+                                                               n2)); // NPerXdl
+            constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4 =
+                make_naive_tensor_descriptor_packed(make_tuple(I1, //
+                                                               I1, //
+                                                               m0, //
+                                                               n0, //
+                                                               m1, //
+                                                               n1, //
+                                                               m2, // m0
+                                                               m3, // m1
+                                                               n2, // n0
+                                                               I1, // m2
+                                                               m4)); // n1
             StaticBuffer<AddressSpaceEnum::Vgpr,
                          unsigned short,
-                         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
+                         z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                          true>
-                z_tenor_buffer;
-            z_tenor_buffer.Clear();
+                z_tenor_buffer_tmp;
+            z_tenor_buffer_tmp.Clear();
             StaticBuffer<AddressSpaceEnum::Vgpr,
                          unsigned short,
-                         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
+                         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4.GetElementSpaceSize(),
                          true>
-                z_tenor_buffer_tmp;
-            z_tenor_buffer_tmp.Clear();
+                z_tenor_buffer;
+            z_tenor_buffer.Clear();
-            // z matrix global desc
-            auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
+            // ignore = p_z_tmp_grid;
             auto z_grid_buf_tmp =
                 make_dynamic_buffer<AddressSpaceEnum::Global>( // tmp buffer for shuffle
                     p_z_grid,
-                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
+                    z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
+            auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4.GetElementSpaceSize());
             const auto wave_id = GetGemm0WaveIdx();
             const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
             auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                 ushort,
                 ZDataType,
-                decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
-                decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
+                decltype(z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
+                decltype(z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
                 tensor_operation::element_wise::PassThrough,
                 Sequence<I1, // MBlockId
                          I1, // NBlockID
                          m0, // MRepeat
                          n0, // NRepeat
                          m1, // MWaveId
                          n1, // NWaveId
                          m2, // MPerXdl
                          m3, // NGroupNum
                          m4, // NInputNum
                          n2>,
                 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                 9, // DstVectorDim
                 1, // DstScalarPerVector
                 InMemoryDataOperationEnum::Set,
                 1, // DstScalarStrideInVector
-                true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                true>{z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                       make_multi_index(0, // MBlockId
                                        block_work_idx_n, // NBlockId
                                        0, // mrepeat
                                        0, // nrepeat
                                        wave_id[I0], // MWaveId
                                        wave_id[I1], // NWaveId
                                        0, // MPerXdl
                                        wave_m_n_id[I0], // group
                                        0, // NInputIndex
                                        wave_m_n_id[I1]),
                       tensor_operation::element_wise::PassThrough{}};
             auto z_tmp_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
                 ZDataType,
                 ushort,
-                decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
-                decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
-                Sequence<I1, I1, m0, n0, m1, n1, m2, m3, m4, n2>,
-                Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                9,
+                decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4),
+                decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4),
+                Sequence<I1, I1, m0, n0, m1, n1, m2, m3, n2, I1, m4>,
+                Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
+                10,
                 1,
                 1,
-                true /* ResetCoordAfterRun */>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                true /* ResetCoordAfterRun */>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
                                                make_multi_index(0, // MBlockId
                                                                 block_work_idx_n, // NBlockId
                                                                 0, // mrepeat
                                                                 0, // nrepeat
                                                                 wave_id[I0], // MWaveId
                                                                 wave_id[I1], // NWaveId
                                                                 0, // MPerXdl
                                                                 wave_m_n_id[I0], // group
-                                                                0, // NInputIndex
-                                                                wave_m_n_id[I1])};
+                                                                int(wave_m_n_id[I1] / 4), // NInputIndex
+                                                                wave_m_n_id[I1] % 4,
+                                                                0)};
             auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                 ushort,
                 ZDataType,
-                decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
-                decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
+                decltype(z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
+                decltype(z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
                 tensor_operation::element_wise::PassThrough,
                 Sequence<I1, // MBlockId
                          I1, // NBlockID
                          m0, // MRepeat
                          n0, // NRepeat
                          m1, // MWaveId
                          n1, // NWaveId
                          m2, // MPerXdl
                          m3, // NGroupNum
                          m4, // NInputNum
                          n2>,
                 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                 9, // DstVectorDim
                 1, // DstScalarPerVector
                 InMemoryDataOperationEnum::Set,
                 1, // DstScalarStrideInVector
-                true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                true>{z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                       make_multi_index(0, // MBlockId
                                        block_work_idx_n, // NBlockId
                                        0, // mrepeat
                                        0, // nrepeat
                                        wave_id[I0], // MWaveId
                                        wave_id[I1], // NWaveId
                                        0, // MPerXdl
                                        wave_m_n_id[I0], // group
                                        0, // NInputIndex
                                        wave_m_n_id[I1]),
                       tensor_operation::element_wise::PassThrough{}};
             //
             // set up Y dot dY
@@ -1981,6 +2019,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+            // if(get_thread_global_1d_id()==0){
+            //     printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
+            //}
+            // auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+            // auto m_local =
+            //     block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+            // auto n_local =
+            //     block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+            // auto m_global = m_local + m_block_data_idx_on_grid;
+            // auto n_global = n_local + n_block_data_idx_on_grid;
+            // if(get_thread_global_1d_id()==32){
+            //     printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
+            //}
             static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
                 auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
                 auto m_local =
@@ -2021,19 +2074,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             {
                 // 8d thread_desc in thread scope
                 constexpr auto c_thread_lengths =
-                    s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                    s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
                 // 8d block_desc in block scope
                 constexpr auto c_block_lengths =
-                    s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                    s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
                 constexpr auto M0 = c_block_lengths[I0];
                 constexpr auto N0 = c_block_lengths[I1];
                 constexpr auto M1 = c_block_lengths[I2];
                 constexpr auto N1 = c_block_lengths[I3];
                 constexpr auto M2 = c_block_lengths[I4];
-                constexpr auto M3 = c_block_lengths[I6];
-                constexpr auto M4 = c_block_lengths[I5];
+                constexpr auto M3 = c_block_lengths[I5];
+                constexpr auto M4 = c_block_lengths[I6];
                 constexpr auto N2 = c_block_lengths[I7];
                 // works like multi-dimension static_for (static_ford), but provides both the linear
@@ -2050,12 +2103,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     make_tuple(Sequence<0>{}, Sequence<1>{}),
                     make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+                // if(get_thread_global_1d_id()==0){
+                //     printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
+                //}
                 auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
                 auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
                 auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
                 auto m_global = m_local + m_block_data_idx_on_grid;
                 auto n_global = n_local + n_block_data_idx_on_grid;
+                // if(get_thread_global_1d_id()==0){
+                //     printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
+                // }
+                // if(get_thread_global_1d_id()==32){
+                //     printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
+                // }
                 auto global_elem_id_raw =
                     MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
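[Editor's note] This per-element id is what keys the Philox stream, so any pass that recomputes it regenerates the identical dropout mask. A small standalone illustration of the id arithmetic (the 500 x 500 sizes echo the example's M and N, chosen here only for the worked number):

// Unique element id: batch-major, then row-major within the M x N score matrix.
#include <cstdio>

long long global_elem_id(long long g, long long m, long long n,
                         long long MRaw, long long NRaw)
{
    return MRaw * NRaw * g + m * NRaw + n;
}

int main()
{
    // e.g. batch 1, element (2, 3) of a 500 x 500 matrix -> 250000 + 1003 = 251003
    printf("%lld\n", global_elem_id(1, 2, 3, 500, 500));
    return 0;
}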
@@ -2082,17 +2145,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     ph, global_elem_id, z_tenor_buffer_tmp, MRaw);
                 z_tmp_thread_copy_vgpr_to_global.Run(
-                    z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                    z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                     make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                     z_tenor_buffer_tmp,
-                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                    z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                     z_grid_buf_tmp);
+                block_sync_lds();
                 z_tmp_thread_copy_global_to_vgpr.Run(
-                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
                     z_grid_buf_tmp,
-                    z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                    make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                    z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
+                    make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                     z_tenor_buffer);
                 blockwise_dropout.template ApplyDropout_v2<decltype(s_slash_p_thread_buf),
@@ -2100,11 +2165,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                                                            true>(s_slash_p_thread_buf,
                                                                  z_tenor_buffer);
-                z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                z_thread_copy_vgpr_to_global.Run(z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                                  z_tenor_buffer,
-                                                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                                                 z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                                                  z_grid_buf);
+                block_sync_lds();
                 //// P_dropped
                 // static_for<0, n0, 1>{}([&](auto i) {
                 //     blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
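[Editor's note] Read together, the three Run calls above implement a global-memory round trip that reorders z: stage the freshly generated randoms through one descriptor layout, sync, reload them through the shuffled n3_m5 layout, apply dropout, then write the final mask. A toy standalone model of the effect (the 2x4 tile and axis swap stand in for the two z layouts; this is an illustration, not the CK mechanism itself):

// Values written through one index order and re-read through another end up
// permuted in registers, which is what the two z descriptors achieve on-GPU.
#include <cstdio>

int main()
{
    unsigned short staged[2][4];
    for(int m = 0; m < 2; ++m)
        for(int n = 0; n < 4; ++n)
            staged[m][n] = static_cast<unsigned short>(m * 4 + n); // staging write

    unsigned short shuffled[4][2];
    for(int n = 0; n < 4; ++n)
        for(int m = 0; m < 2; ++m)
            shuffled[n][m] = staged[m][n]; // reload through the swapped layout

    printf("%u %u\n", staged[1][2], shuffled[2][1]); // same element, two layouts
    return 0;
}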
@@ -2132,11 +2199,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             {
                 // 8d thread_desc in thread scope
                 constexpr auto c_thread_lengths =
-                    s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                    s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
                 // 8d block_desc in block scope
                 constexpr auto c_block_lengths =
-                    s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+                    s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
                 constexpr auto M0 = c_block_lengths[I0];
                 constexpr auto N0 = c_block_lengths[I1];
@@ -2362,13 +2429,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                     qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
                 z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                    z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                     make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
                 z_tmp_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
-                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                    make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
+                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
+                    make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
                 z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                    z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                     make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
                 lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
                                                                   make_multi_index(1, 0, 0, 0, 0, 0));
...