Commit f82a220f authored by guangzlu

v4 pass

parent ff88ffa4
@@ -766,7 +766,7 @@ int run(int argc, char* argv[])
auto argument = gemm.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
- static_cast<ZDataType*>(nullptr), // set to nullptr
+ static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
......
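Note: the hunk above switches the forward example from a null z pointer to a real device buffer, so the kernel can write its per-element dropout randomness (the z tensor) to global memory. Below is a minimal host-side sketch of how such a buffer can be allocated with the host utilities used elsewhere in this diff; the ZDataType alias, lengths and strides are illustrative assumptions, not the example's actual values.

```cpp
// Sketch only: allocate a z buffer and obtain the device pointer that
// MakeArgument now receives instead of nullptr. Shapes are made up.
#include <cstddef>
#include <vector>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"

using ZDataType = unsigned short; // assumption: matches the example's ZDataType

int main()
{
    // G0, G1, M, N with a packed layout; illustrative values only
    std::vector<std::size_t> lengths = {2, 1, 8, 8};
    std::vector<std::size_t> strides = {64, 64, 8, 1};

    Tensor<ZDataType> z_gs_ms_ns(lengths, strides);
    DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());

    // This device pointer is what the argument now carries instead of nullptr.
    void* p_z = z_device_buf.GetDeviceBuffer();
    (void)p_z;
    return 0;
}
```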
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
- #define DIM 64 // DIM should be a multiple of 8.
+ #define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
@@ -710,17 +710,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
- ck::index_t M = 1000; // 512
- ck::index_t N = 1000; // 512
+ ck::index_t M = 500; // 512
+ ck::index_t N = 500; // 512
ck::index_t K = DIM;
ck::index_t O = DIM;
- ck::index_t G0 = 4; // 54
- ck::index_t G1 = 6; // 16
+ ck::index_t G0 = 2; // 54
+ ck::index_t G1 = 1; // 16
bool input_permute = false;
bool output_permute = false;
- float p_drop = 0.0;
+ float p_drop = 0.1;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
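Note: with p_drop raised from 0.0 to 0.1, dropout is actually exercised. Each attention probability is kept or zeroed based on a Philox-generated 16-bit value (driven by the `seed`/`offset` constants above), and surviving values are rescaled by 1/(1 - p_drop). A small host-side sketch of that decision follows; the threshold derivation and the direction of the comparison are assumptions for illustration, not the kernel's exact convention.

```cpp
// Illustrative dropout decision against a 16-bit random value.
#include <cstdint>
#include <cstdio>

int main()
{
    float p_drop  = 0.1f;
    float rescale = 1.0f / (1.0f - p_drop);                       // keep-probability rescale
    uint16_t drop_threshold = static_cast<uint16_t>(p_drop * 65535.0f); // assumption

    uint16_t rand16 = 4242;  // stand-in for one Philox-generated 16-bit value
    float p_in  = 0.7f;      // an attention probability before dropout
    float p_out = (rand16 <= drop_threshold) ? 0.0f : p_in * rescale;

    std::printf("p_in=%f p_out=%f\n", p_in, p_out);
    return 0;
}
```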
@@ -944,7 +944,7 @@ int run(int argc, char* argv[])
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
- static_cast<ZDataType*>(nullptr),
+ static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
@@ -998,7 +998,7 @@ int run(int argc, char* argv[])
auto argument_bwd = gemm_bwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
- static_cast<ZDataType*>(nullptr), // set to nullptr
+ static_cast<ZDataType*>(z_bwd_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
@@ -1399,20 +1399,20 @@ int run(int argc, char* argv[])
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData,
"error",
- 1e-2,
- 1e-2);
+ 1e-3,
+ 1e-3);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData,
"error",
- 1e-2,
- 1e-2);
+ 1e-3,
+ 1e-3);
std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData,
"error",
- 1e-2,
- 1e-2);
+ 1e-3,
+ 1e-3);
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
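Note: the three checks above tighten both trailing tolerance arguments of ck::utils::check_err from 1e-2 to 1e-3. The sketch below shows the usual combined relative/absolute test such a pair of tolerances implies; the exact formula inside check_err is an assumption here.

```cpp
// Combined-tolerance comparison, illustrative only.
#include <cmath>
#include <cstdio>

bool close_enough(double out, double ref, double rtol, double atol)
{
    return std::fabs(out - ref) <= atol + rtol * std::fabs(ref);
}

int main()
{
    std::printf("%d\n", close_enough(1.0005, 1.0, 1e-3, 1e-3)); // passes at the tighter 1e-3
    std::printf("%d\n", close_enough(1.05, 1.0, 1e-3, 1e-3));   // would fail
    return 0;
}
```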
......
@@ -145,14 +145,14 @@ struct BlockwiseDropout
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
}
- ushort tmp_id[tmp_size];
- for(int i = 0; i < philox_calls; i++)
- {
- for(int j = 0; j < 4; j++)
- {
- tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
- }
- }
+ // ushort tmp_id[tmp_size];
+ // for(int i = 0; i < philox_calls; i++)
+ //{
+ // for(int j = 0; j < 4; j++)
+ // {
+ // tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
+ // }
+ //}
block_sync_lds();
@@ -162,7 +162,7 @@ struct BlockwiseDropout
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
- z_thread_buf(offset) = tmp_id[tmp_index];
+ z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
@@ -208,17 +208,17 @@ struct BlockwiseDropout
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
- ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
+ ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
}
- ushort tmp_id[tmp_size];
- for(int i = 0; i < philox_calls; i++)
- {
- for(int j = 0; j < 4; j++)
- {
- tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
- }
- }
+ // ushort tmp_id[tmp_size];
+ // for(int i = 0; i < philox_calls; i++)
+ //{
+ // for(int j = 0; j < 4; j++)
+ // {
+ // tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
+ // }
+ //}
block_sync_lds();
@@ -226,7 +226,7 @@ struct BlockwiseDropout
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
- z_thread_buf(offset) = tmp_id[tmp_index];
+ z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
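Note: the net effect of the BlockwiseDropout changes above is that the z buffer now records the Philox random values themselves (`tmp`) instead of the element ids (`tmp_id`, now commented out), and the second variant advances the per-call Philox offset by 8 * MRaw rather than 8. A host-side sketch of just that offset pattern follows; the loop bounds and the MRaw value are illustrative and the generator itself is not shown.

```cpp
// Offset pattern only: how the per-call subsequence id advances in the
// original (stride 8) versus the new (stride 8 * MRaw) variant.
#include <cstdio>

int main()
{
    const int philox_calls = 4;
    const int MRaw         = 500;                 // illustrative, matches the example's M
    unsigned long long element_global_1d_id = 0;  // per-thread base id (illustrative)

    for(int i = 0; i < philox_calls; ++i)
    {
        unsigned long long off_v1 = element_global_1d_id + i * 8;        // original stride
        unsigned long long off_v2 = element_global_1d_id + i * 8 * MRaw; // new stride
        std::printf("call %d: v1 offset %llu, v2 offset %llu\n", i, off_v1, off_v2);
    }
    return 0;
}
```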
......
@@ -22,6 +22,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
namespace ck {
namespace tensor_operation {
@@ -40,6 +41,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
+ typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
@@ -73,6 +75,8 @@ __global__ void
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
+ c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -138,6 +142,7 @@ __global__ void
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
@@ -173,6 +178,7 @@ __global__ void
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
@@ -828,6 +834,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
+ c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_ =
+ GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4(
+ z_grid_desc_m_n_);
+ // tmp z tensor for shuffle
+ // Tensor<ZDataType> z_tmp_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
+ // DeviceMem z_tmp_device_buf(sizeof(ZDataType) *
+ // z_tmp_gs_ms_ns.mDesc.GetElementSpaceSize());
+ // z_tmp_device_buf.ToDevice(z_tmp_gs_ms_ns.mData.data());
+ // p_z_tmp_grid_ = reinterpret_cast<ZDataType*>(z_tmp_device_buf.GetDeviceBuffer());
// Print();
}
@@ -859,7 +877,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// pointers
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
+ // ZDataType* p_z_tmp_grid_;
ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
@@ -890,6 +911,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
+ typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
+ c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -952,6 +976,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
+ typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
@@ -986,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
+ arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
......
@@ -122,6 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
@@ -132,6 +133,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
}
+ __host__ __device__ static constexpr auto
+ MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4(const ZGridDesc_M_N& z_grid_desc_m_n) //
+ {
+ const auto M = z_grid_desc_m_n.GetLength(I0);
+ const auto N = z_grid_desc_m_n.GetLength(I1);
+ constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
+ constexpr auto M3 = mfma.num_groups_per_blk; // 4
+ constexpr auto M4 = mfma.num_input_blks; // 2
+ constexpr auto M5 = mfma.group_size; // 4
+ return transform_tensor_descriptor(
+ z_grid_desc_m_n,
+ make_tuple(make_unmerge_transform(
+ make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
+ make_unmerge_transform(
+ make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl / M5, M5))),
+ make_tuple(Sequence<0>{}, Sequence<1>{}),
+ make_tuple(Sequence<0, 2, 4, 6, 7, 9>{}, Sequence<1, 3, 5, 8, 10>{}));
+ }
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
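Note: the new MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4 describes the same z tensor as the existing M5_N3 descriptor but splits the N side into (N/NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl/M5, M5), which is what allows the mask to be read back in a different per-thread order after the shuffle. Below is a plain-C++ illustration of what a single unmerge transform does to one index, with the last listed factor varying fastest; the factor values are the MFMA-derived ones noted in the comments plus made-up block/wave counts, not the CK descriptor itself.

```cpp
// Decompose a merged row index m into (m0, m1, m2, m3, m4, m5); illustrative only.
#include <cstdio>

int main()
{
    // lengths in the order passed to make_unmerge_transform for the M side
    const int len[6] = {2 /*M/MPerBlock*/, 2 /*MXdlPerWave*/, 2 /*Gemm0MWaves*/,
                        4 /*M3*/, 2 /*M4*/, 4 /*M5*/};
    int m = 173; // some row index in [0, 2*2*2*4*2*4)
    int idx[6];
    for(int d = 5; d >= 0; --d) // last dimension is fastest
    {
        idx[d] = m % len[d];
        m /= len[d];
    }
    std::printf("m -> (%d, %d, %d, %d, %d, %d)\n",
                idx[0], idx[1], idx[2], idx[3], idx[4], idx[5]);
    return 0;
}
```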
@@ -399,6 +420,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
+ using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5 = remove_cvref_t<decltype( // for shuffle
+ MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4(ZGridDesc_M_N{}))>;
// Q / K / V / dY
struct GemmBlockwiseCopy
{
@@ -1235,7 +1259,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
- z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5&
+ z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
@@ -1581,50 +1607,63 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
constexpr auto z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2)); // registerNum
m2, // MGroupNum
m3, // MInputNum
m4, // registerNum
n2)); // NPerXdl
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, //
I1, //
m0, //
n0, //
m1, //
n1, //
m2, // m0
m3, // m1
n2, // n0
I1, // m2
m4)); // n1
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
z_tenor_buffer_tmp;
z_tenor_buffer_tmp.Clear();
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4.GetElementSpaceSize(),
true>
z_tenor_buffer_tmp;
z_tenor_buffer_tmp.Clear();
z_tenor_buffer;
z_tenor_buffer.Clear();
// z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
// ignore = p_z_tmp_grid;
auto z_grid_buf_tmp =
make_dynamic_buffer<AddressSpaceEnum::Global>( // tmp buffer for shuffle
p_z_grid,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_tmp_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
@@ -1641,7 +1680,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
- true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ true>{z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
@@ -1654,18 +1693,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
wave_m_n_id[I1]),
tensor_operation::element_wise::PassThrough{}};
auto z_tmp_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<ZDataType,
auto z_tmp_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
Sequence<I1, I1, m0, n0, m1, n1, m2, m3, m4, n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4),
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4),
Sequence<I1, I1, m0, n0, m1, n1, m2, m3, n2, I1, m4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
10,
1,
1,
true /* ResetCoordAfterRun */>{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
true /* ResetCoordAfterRun */>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
make_multi_index(0, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
@@ -1674,15 +1712,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], // group
0, // NInputIndex
wave_m_n_id[I1])};
int(wave_m_n_id[I1] / 4), // NInputIndex
wave_m_n_id[I1] % 4,
0)};
auto z_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
......@@ -1699,7 +1737,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
true>{z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
@@ -1981,6 +2019,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+ // if(get_thread_global_1d_id()==0){
+ // printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
+ //}
+ // auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+ // auto m_local =
+ // block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+ // auto n_local =
+ // block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+ // auto m_global = m_local + m_block_data_idx_on_grid;
+ // auto n_global = n_local + n_block_data_idx_on_grid;
+ // if(get_thread_global_1d_id()==32){
+ // printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
+ //}
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
@@ -2021,19 +2074,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
- s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+ s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
- s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+ s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
- constexpr auto M3 = c_block_lengths[I6];
- constexpr auto M4 = c_block_lengths[I5];
+ constexpr auto M3 = c_block_lengths[I5];
+ constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
@@ -2050,12 +2103,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
@@ -2082,17 +2145,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ph, global_elem_id, z_tenor_buffer_tmp, MRaw);
z_tmp_thread_copy_vgpr_to_global.Run(
- z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer_tmp,
- z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf_tmp);
block_sync_lds();
z_tmp_thread_copy_global_to_vgpr.Run(
- z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
z_grid_buf_tmp,
- z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
- make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+ z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer);
blockwise_dropout.template ApplyDropout_v2<decltype(s_slash_p_thread_buf),
@@ -2100,11 +2165,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
true>(s_slash_p_thread_buf,
z_tenor_buffer);
- z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ z_thread_copy_vgpr_to_global.Run(z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
- z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
block_sync_lds();
//// P_dropped
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
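Note: taken together, the two hunks above shuffle the dropout mask through global memory: the per-thread Philox values are first written out through the tmp (register-order) descriptor, the block synchronizes, the values are re-read through the new _n3_m5_n4 descriptor, dropout is applied to the attention probabilities with those reloaded values, and the result is written out again. Below is a CPU-side sketch of that "write in one order, read back in another" pattern; the layouts are toy permutations, not the CK descriptors.

```cpp
// Round-trip reordering through a scratch buffer, illustrative only.
#include <array>
#include <cstdio>

int main()
{
    constexpr int N = 8;
    std::array<unsigned short, N> regs{};    // values in the order a thread produced them
    for(int i = 0; i < N; ++i) regs[i] = static_cast<unsigned short>(100 + i);

    std::array<unsigned short, N> scratch{}; // global "tmp" buffer
    for(int i = 0; i < N; ++i)
        scratch[i] = regs[i];                // vgpr -> global, producer layout

    // global -> vgpr with a different mapping (here: a simple 2x4 transpose)
    std::array<unsigned short, N> shuffled{};
    for(int r = 0; r < 2; ++r)
        for(int c = 0; c < 4; ++c)
            shuffled[c * 2 + r] = scratch[r * 4 + c];

    for(int i = 0; i < N; ++i) std::printf("%u ", shuffled[i]);
    std::printf("\n");
    return 0;
}
```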
@@ -2132,11 +2199,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
- s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+ s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
- s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+ s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
@@ -2362,13 +2429,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
- z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_tmp_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
- z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
- make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
+ z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4,
+ make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
- z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+ z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(1, 0, 0, 0, 0, 0));
......