"vscode:/vscode.git/clone" did not exist on "3848606c7ed98c585b7a41397f99e1a873b17f61"
Commit 289e1196 authored by letaoqin's avatar letaoqin
Browse files

multiple M block

parent a33c100d
......@@ -273,8 +273,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 64;
ck::index_t N = 128;
ck::index_t M = 128;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 1;
......@@ -468,7 +468,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); //dy[g0,g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
......
......@@ -122,11 +122,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
// D0
static constexpr auto D0M1 = Number<4>{};
static constexpr auto D0M0 = Number<MPerBlock / D0M1.value>{};
// static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
......@@ -1157,21 +1152,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_block_bytes_end);
}
// D0
static constexpr auto D0M1 = Number<4>{};
static constexpr auto D0M0 = Number<MPerBlock>{} / D0M1;
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
// const auto M = d0_grid_desc_m_n.GetLength(I0);
// const auto N = d0_grid_desc_m_n.GetLength(I1);
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
// const auto MBlock = M / MPerBlock;
// const auto NBlock = N / NPerBlock;
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(D0M0, D0M1)),
make_unmerge_transform(make_tuple(Number<NPerBlock>{}))),
make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1)),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}));
return d0_grid_desc_m0_n0_m1_m2_n1_m3;
}
......@@ -1184,8 +1183,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));
return make_naive_tensor_descriptor(make_tuple(I1, I1, D0M0, Number<NPerBlock>{}, D0M1),
make_tuple(Number<NPerBlock>{} * D0M1,
Number<NPerBlock>{} * D0M1,
Number<NPerBlock>{} * D0M1,
D0M1,
I1));
}
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
{
......@@ -1215,17 +1218,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<D0M0, NPerBlock, D0M1>, // BlockSliceLengths
Sequence<8, 32, 1>, // ThreadClusterLengths
Sequence<0, 2, 1>, // ThreadClusterArrangeOrder
Sequence<I1, I1, D0M0, NPerBlock, D0M1>, // BlockSliceLengths
Sequence<1, 1, 8, 32, 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 4, 3>, // ThreadClusterArrangeOrder
D0DataType, // SrcData
D0DataType, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 2, 1>, // SrcDimAccessOrder
Sequence<1, 0, 2>, // DstDimAccessOrder
1, // SrcVectorDim
2, // DstVectorDim
Sequence<0, 1, 2, 4, 3>, // SrcDimAccessOrder
Sequence<0, 1, 3, 2, 4>, // DstDimAccessOrder
3, // SrcVectorDim
4, // DstVectorDim
4, // SrcScalarPerVector
4, // DstScalarPerVector
1,
......@@ -1242,8 +1245,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
Sequence<1, 1, 8, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
4, // SrcScalarPerVector
4>;
2, // SrcScalarPerVector
2>;
};
template <bool HasMainKBlockLoop,
......@@ -1730,17 +1733,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// wave_m_n_id[I0],
// wave_m_n_id[I1]);
// }
// D0
auto d0_block_copy_global_to_lds =
typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
ignore = d0_thread_copy_lds_to_vgpr;
//
// set up Y dot dY
//
......@@ -1833,6 +1825,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
do
{
auto m_block_data_idx_on_grid =
......@@ -2011,7 +2015,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_buf);
// d0_block_copy_global_to_lds.MoveSrcSliceWindow(
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
d0_block_buf);
......@@ -2029,7 +2033,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(mr, i));
//if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f)
// {
// float tmp_lds =
......@@ -2049,9 +2053,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
});
//});
// d0_block_copy_global_to_lds.MoveSrcSliceWindow(
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, -D0M1.value, 0, 0,
// 0));
d0_block_copy_global_to_lds.MoveSrcSliceWindow(d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, 0, 0, 0));
}
// P_i: = softmax(scalar * S_i:)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment