Commit 289e1196 authored by letaoqin
Browse files

multiple M block

parent a33c100d
...@@ -273,8 +273,8 @@ int run(int argc, char* argv[]) ...@@ -273,8 +273,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 64; ck::index_t M = 128;
ck::index_t N = 128; ck::index_t N = 512;
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 1; ck::index_t G0 = 1;
...@@ -468,7 +468,7 @@ int run(int argc, char* argv[]) ...@@ -468,7 +468,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); //dy[g0,g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
......
...@@ -122,11 +122,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -122,11 +122,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{}; static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{}; static constexpr auto B1K1 = Number<B1K1Value>{};
// D0
static constexpr auto D0M1 = Number<4>{};
static constexpr auto D0M0 = Number<MPerBlock / D0M1.value>{};
// static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2 static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time // get_random_8x16() generates 8 random numbers each time
...@@ -1157,21 +1152,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1157,21 +1152,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_block_bytes_end); c_block_bytes_end);
} }
// D0
static constexpr auto D0M1 = Number<4>{};
static constexpr auto D0M0 = Number<MPerBlock>{} / D0M1;
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n) MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
{ {
// const auto M = d0_grid_desc_m_n.GetLength(I0); const auto M = d0_grid_desc_m_n.GetLength(I0);
// const auto N = d0_grid_desc_m_n.GetLength(I1); const auto N = d0_grid_desc_m_n.GetLength(I1);
// const auto MBlock = M / MPerBlock; const auto MBlock = M / MPerBlock;
// const auto NBlock = N / NPerBlock; const auto NBlock = N / NPerBlock;
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor( const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
d0_grid_desc_m_n, d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(D0M0, D0M1)), make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M0, D0M1)),
make_unmerge_transform(make_tuple(Number<NPerBlock>{}))), make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}));
return d0_grid_desc_m0_n0_m1_m2_n1_m3; return d0_grid_desc_m0_n0_m1_m2_n1_m3;
} }
...@@ -1184,8 +1183,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1184,8 +1183,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3() __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
{ {
// B1 matrix in LDS memory, dst of blockwise copy // B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1), return make_naive_tensor_descriptor(make_tuple(I1, I1, D0M0, Number<NPerBlock>{}, D0M1),
make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1)); make_tuple(Number<NPerBlock>{} * D0M1,
Number<NPerBlock>{} * D0M1,
Number<NPerBlock>{} * D0M1,
D0M1,
I1));
} }
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3() __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
{ {
...@@ -1215,17 +1218,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1215,17 +1218,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<D0M0, NPerBlock, D0M1>, // BlockSliceLengths Sequence<I1, I1, D0M0, NPerBlock, D0M1>, // BlockSliceLengths
Sequence<8, 32, 1>, // ThreadClusterLengths Sequence<1, 1, 8, 32, 1>, // ThreadClusterLengths
Sequence<0, 2, 1>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 4, 3>, // ThreadClusterArrangeOrder
D0DataType, // SrcData D0DataType, // SrcData
D0DataType, // DstData D0DataType, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_desc_m0_n0_m1_m2_n1_m3), // DstDesc decltype(d0_block_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 2, 1>, // SrcDimAccessOrder Sequence<0, 1, 2, 4, 3>, // SrcDimAccessOrder
Sequence<1, 0, 2>, // DstDimAccessOrder Sequence<0, 1, 3, 2, 4>, // DstDimAccessOrder
1, // SrcVectorDim 3, // SrcVectorDim
2, // DstVectorDim 4, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
4, // DstScalarPerVector 4, // DstScalarPerVector
1, 1,
...@@ -1242,8 +1245,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1242,8 +1245,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
Sequence<1, 1, 8, 1, 4>, // SliceLengths Sequence<1, 1, 8, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
4, // SrcScalarPerVector 2, // SrcScalarPerVector
4>; 2>;
}; };
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
...@@ -1730,17 +1733,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1730,17 +1733,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// wave_m_n_id[I0], // wave_m_n_id[I0],
// wave_m_n_id[I1]); // wave_m_n_id[I1]);
// } // }
// D0
auto d0_block_copy_global_to_lds =
typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
ignore = d0_thread_copy_lds_to_vgpr;
// //
// set up Y dot dY // set up Y dot dY
// //
...@@ -1833,6 +1825,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1833,6 +1825,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// gemm0 M loop // gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1; index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
do do
{ {
auto m_block_data_idx_on_grid = auto m_block_data_idx_on_grid =
...@@ -2011,7 +2015,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2011,7 +2015,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_buf); d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_buf);
// d0_block_copy_global_to_lds.MoveSrcSliceWindow( // d0_block_copy_global_to_lds.MoveSrcSliceWindow(
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); // d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3, d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
d0_block_buf); d0_block_buf);
...@@ -2029,7 +2033,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2029,7 +2033,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// constexpr index_t c_offset = // constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(mr, i)); // c_thread_desc_.CalculateOffset(make_tuple(mr, i));
//if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f) // if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f)
// { // {
// float tmp_lds = // float tmp_lds =
...@@ -2049,9 +2053,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2049,9 +2053,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}); });
//}); //});
// d0_block_copy_global_to_lds.MoveSrcSliceWindow( d0_block_copy_global_to_lds.MoveSrcSliceWindow(d0_grid_desc_m0_n0_m1_m2_n1_m3,
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, -D0M1.value, 0, 0, make_multi_index(-1, 0, 0, 0, 0));
// 0));
} }
// P_i: = softmax(scalar * S_i:) // P_i: = softmax(scalar * S_i:)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment