Commit 3b58c3ec authored by guangzlu's avatar guangzlu
Browse files

saved dropout random number in gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp

parent cc974f0f
...@@ -35,6 +35,7 @@ template <typename FloatAB, ...@@ -35,6 +35,7 @@ template <typename FloatAB,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename ZGridDesc_M_N,
typename LSEGridDesc_M, typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -116,6 +117,65 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -116,6 +117,65 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N) ////=> for z use
{
constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(M, N)),
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
...@@ -319,10 +379,13 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -319,10 +379,13 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
struct SharedMemTrait struct SharedMemTrait
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -367,6 +430,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -367,6 +430,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
unsigned short* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid, FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -379,6 +443,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -379,6 +443,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
...@@ -782,6 +848,80 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -782,6 +848,80 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
// gemm1 K loop // gemm1 K loop
index_t gemm1_k_block_outer_index = 0; index_t gemm1_k_block_outer_index = 0;
///////////////////=>z for dropout
//
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
n4, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
///////////////////=>z for dropout
do do
{ {
auto n_block_data_idx_on_grid = auto n_block_data_idx_on_grid =
...@@ -876,9 +1016,34 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -876,9 +1016,34 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
if constexpr(IsDropout) // dropout if constexpr(IsDropout) // dropout
{ {
blockwise_dropout.ApplyDropout(acc_thread_buf, ph); // save z to global
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
true>(
acc_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), true>(
acc_thread_buf, ph);
}
} }
//if constexpr(IsDropout) // dropout
//{
// blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
//}
// TODO: may convert to log domain // TODO: may convert to log domain
running_max_new = mathext::max(max, running_max); running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment