Commit ad2f82ac authored by rocking

Sync code, prepare to test on MI200

parent 1d7290fb
@@ -284,7 +284,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
     }

     template <typename LayOut>
-    static auto MakeGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
+    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
     {
         const auto grid_desc_mraw_nraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
@@ -308,11 +308,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
             [&](auto i) {
                 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-                return DeviceOp::MakeGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
+                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
             },
             Number<NumDTensor>{});
     }

+    static auto MakeMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
+    {
+        const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
+
+        // TODO - padding according to MNperBlock of Gemm and Layernorm
+        return grid_desc_m_n;
+    }
+
     static auto MakeDescriptor_M(index_t MRaw)
     {
         const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -366,9 +374,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
     using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
     using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
-    using MeanVarCountGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
+    using MeanVarCountGridDesc_M_N = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
     using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
-    using EHGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
+    using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));

     using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
@@ -479,11 +487,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
               b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
               ds_grid_desc_m_n_{},
-              e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
+              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
               mean_var_count_grid_desc_m_n_{},
               gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
               beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
-              h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
+              h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
               a_grid_desc_ak0_m_ak1_{
                   GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
               b_grid_desc_bk0_n_bk1_{
@@ -497,7 +505,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               epsilon_{epsilon}
         {
             mean_var_count_grid_desc_m_n_ =
-                DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, gemm_nblock_, gemm_nblock_);
+                DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
+
+            int s = mean_var_count_grid_desc_m_n_.GetElementSpaceSize();
+            printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);

             hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
@@ -518,7 +529,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                 // D desc
                 ds_grid_desc_m_n_(i) =
-                    DeviceOp::MakeGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
+                    DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
             });

             // populate desc for Ds/E/F/G
@@ -526,7 +537,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                    b_grid_desc_n_k_,
                    ds_grid_desc_m_n_,
                    e_grid_desc_m_n_,
-                   mean_var_count_grid_desc_m_n_,
                    block_2_etile_map_))
             {
                 ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
@@ -612,7 +622,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                    arg.b_grid_desc_n_k_,
                    arg.ds_grid_desc_m_n_,
                    arg.e_grid_desc_m_n_,
-                   arg.mean_var_count_grid_desc_m_n_,
                    arg.block_2_etile_map_))
             {
                 throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
@@ -694,7 +703,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                     LayernormThreadClusterSize_M_N::At(I1)) /
                 LayernormThreadClusterSize_M_N::At(I1);
-            index_t numXBlockTileIteration_N =
+            index_t numEBlockTileIteration_N =
                 math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
                 LayernormBlockTileSize_M_N::At(I1);
@@ -717,7 +726,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                     arg.beta_grid_desc_n_,
                     arg.gemm_nblock_,
                     numMeanVarCountBlockTileIteration_N,
-                    numXBlockTileIteration_N,
+                    numEBlockTileIteration_N,
                     arg.epsilon_);

             return avg_time;
...
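A note on the workspace change above: `MakeMeanVarCountGridDescriptor_M_NBlock` replaces the strided E-layout descriptor with a packed (M, NBlock) one, where NBlock (stored as `gemm_nblock_`) is the number of GEMM tiles along N, and the first kernel writes one (mean, var, count) partial per N-tile per row. A minimal host-side sketch of the sizing arithmetic, with an illustrative `NPerBlock` value; this is plain C++, not the CK API:

```cpp
#include <cstdio>

// Hypothetical standalone sketch: the first kernel emits one Welford partial
// per GEMM N-tile, so the workspace is a packed M x NBlock grid, where
// NBlock = ceil(N / NPerBlock) is what the Argument stores as gemm_nblock_.
int main()
{
    const int M = 1024, N = 1024, NPerBlock = 128; // illustrative values

    const int NBlock = (N + NPerBlock - 1) / NPerBlock; // plays the role of gemm_nblock_

    // Packed descriptor => element space size is simply M * NBlock.
    const long long elem_space = static_cast<long long>(M) * NBlock;

    std::printf("NBlock = %d, mean/var/count workspace elements = %lld\n",
                NBlock, elem_space);
    return 0;
}
```

Since the descriptor is packed, the `GetElementSpaceSize()` printf added to the Argument constructor should report exactly this M * NBlock value.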
@@ -269,13 +269,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     template <typename Block2ETileMap>
-    __host__ __device__ static constexpr bool
-    CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
-                  const BGridDesc_N_K& b_grid_desc_n_k,
-                  const DsGridDesc_M_N& ds_grid_desc_m_n,
-                  const EGridDesc_M_N& e_grid_desc_m_n,
-                  const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
-                  const Block2ETileMap& block_2_etile_map)
+    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
+                                                            const BGridDesc_N_K& b_grid_desc_n_k,
+                                                            const DsGridDesc_M_N& ds_grid_desc_m_n,
+                                                            const EGridDesc_M_N& e_grid_desc_m_n,
+                                                            const Block2ETileMap& block_2_etile_map)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -286,9 +284,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
         const auto K = a_grid_desc_m_k.GetLength(I1);

         // check consistency of desc
-        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
-             M == mean_var_count_grid_desc_m_n.GetLength(I0) &&
-             N / NPerBlock == mean_var_count_grid_desc_m_n.GetLength(I1)))
+        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
         {
             return false;
         }
@@ -997,6 +993,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                 static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                     block_sync_lds();
+                    count_thread_buf = threadwise_welfords(i).cur_count_;
                     BlockwiseWelford::Run(
                         mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
                 });
@@ -1083,6 +1080,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                     count_thread_buf,
                     mean_var_count_grid_desc_mblock_mperblock_nblock,
                     welford_count_grid_buf);
+
+                float mean = static_cast<float>(mean_thread_buf(I0));
+                float var = static_cast<float>(var_thread_buf(I0));
+                int count = count_thread_buf(I0);
+                if(i == 0 && get_thread_global_1d_id() == 0)
+                    printf("1st kernel mean = %f, var = %f, count = %d\n", mean, var, count);
             });
         } // shuffle C + Ds + welford + write out
...
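For reference, the math behind the `ThreadwiseWelford`/`BlockwiseWelford` calls above is the standard single-pass Welford update plus the pairwise combine of partial results (Chan et al.). A scalar sketch under that assumption; the struct below is illustrative, not the CK API:

```cpp
#include <cstdio>

// Minimal sketch of the Welford math the two kernels split between them.
struct Welford
{
    float mean = 0.f;
    float m2   = 0.f; // sum of squared deviations; var = m2 / count
    int count  = 0;

    // Single-pass update with one new sample (the threadwise phase).
    void update(float x)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }

    // Pairwise merge of two partials, e.g. across threads of a block or
    // across the NBlock partials written to the (M, NBlock) workspace.
    static Welford merge(const Welford& a, const Welford& b)
    {
        Welford out;
        out.count = a.count + b.count;
        if(out.count == 0)
            return out;
        const float delta = b.mean - a.mean;
        out.mean = a.mean + delta * b.count / out.count;
        out.m2   = a.m2 + b.m2 + delta * delta * a.count * b.count / out.count;
        return out;
    }
};

int main()
{
    Welford a, b;
    for(float x : {1.f, 2.f, 3.f, 4.f}) a.update(x);
    for(float x : {5.f, 6.f, 7.f, 8.f}) b.update(x);

    const Welford m = Welford::merge(a, b);
    // Prints mean = 4.5, var = 5.25, count = 8 for the samples 1..8.
    std::printf("mean = %f, var = %f, count = %d\n", m.mean, m.m2 / m.count, m.count);
    return 0;
}
```

Because the blockwise merge consumes and updates the running counts, the added `count_thread_buf = threadwise_welfords(i).cur_count_;` line appears to re-seed the counts from the threadwise state before each merge rather than letting previously merged counts accumulate.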
@@ -34,15 +34,23 @@ template <typename EDataType,
           index_t NThreadClusterSize,
           index_t MThreadSliceSize,
           index_t NThreadSliceSize,
-          index_t ESrcYDstVectorDim,
+          index_t ESrcHDstVectorDim,
           index_t ESrcVectorSize,
-          index_t YDstVectorSize,
+          index_t HDstVectorSize,
           index_t GammaSrcVectorSize,
           index_t BetaSrcVectorSize,
           index_t MeanVarSrcDstVectorSize>
 struct GridwiseWelfordSecondHalfLayernorm2d
 {
-    static constexpr bool reorder_thread_cluster = (ESrcYDstVectorDim == 0);
+    static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % ESrcVectorSize == 0) ||
+                      (ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % HDstVectorSize == 0) ||
+                      (ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static constexpr bool reorder_thread_cluster = (ESrcHDstVectorDim == 0);

     using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
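The two new static_asserts tie the thread slice lengths to the E-load/H-store vector widths: whichever dimension is vectorized must have a slice length divisible by the vector size. An illustrative compile-time check with hypothetical values, not taken from the commit:

```cpp
// Hypothetical instantiation: E is read and H is written in vector chunks,
// so the slice length along the vector dimension must divide evenly.
constexpr int MThreadSliceSize  = 1;
constexpr int NThreadSliceSize  = 8;
constexpr int ESrcHDstVectorDim = 1; // vectorize along N
constexpr int ESrcVectorSize    = 4;
constexpr int HDstVectorSize    = 4;

static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % ESrcVectorSize == 0) ||
                  (ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0),
              "Invalid thread slice sizes and/or vector sizes configuration, please check!");

// Each thread then stages NThreadSliceSize / ESrcVectorSize = 2 vector chunks
// per row, which is what the EThreadBufferNumber constant in the next hunk counts.
constexpr int EThreadBufferNumber = NThreadSliceSize / ESrcVectorSize;
static_assert(EThreadBufferNumber == 2, "two vector chunks per thread row");
```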
@@ -73,8 +81,14 @@ struct GridwiseWelfordSecondHalfLayernorm2d
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
+    static constexpr index_t N_BlockTileStepSize = NThreadClusterSize * ESrcVectorSize;
+
+    static constexpr auto EThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
+    static constexpr auto GammaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
+    static constexpr auto BetaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
+    static constexpr auto HThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};

     __device__ static void Run(const EDataType* __restrict__ p_e_grid,
                                const MeanDataType* __restrict__ p_in_welford_mean_grid,
@@ -89,8 +103,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                const GammaBetaGridDesc_N& gamma_grid_desc_n,
                                const GammaBetaGridDesc_N& beta_grid_desc_n,
                                index_t gemm_nblock_,
-                               index_t num_mean_var_count_k_block_tile_iteration,
-                               index_t num_xy_k_block_tile_iteration,
+                               index_t numMeanVarCountBlockTileIteration_N,
+                               index_t numEBlockTileIteration_N,
                                ComputeDataType epsilon)
     {
         ignore = p_e_grid;
@@ -106,10 +120,206 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         ignore = gamma_grid_desc_n;
         ignore = beta_grid_desc_n;
         ignore = gemm_nblock_;
-        ignore = num_mean_var_count_k_block_tile_iteration;
-        ignore = num_xy_k_block_tile_iteration;
+        ignore = numMeanVarCountBlockTileIteration_N;
+        ignore = numEBlockTileIteration_N;
         ignore = epsilon;
+
+        // float mean = static_cast<float>(p_in_welford_mean_grid[0]);
+        // float var = static_cast<float>(p_in_welford_var_grid[0]);
+        // int count = p_in_welford_count_grid[0];
+        // if(get_thread_global_1d_id() == 0)
+        //     printf("kernel mean = %f, var = %f, count = %d\n", mean, var, count);
+
+        float mean = static_cast<float>(p_in_welford_mean_grid[0]);
+        if(get_thread_global_1d_id() == 0)
+            printf("mean = %f\n", mean);
+
+        int s = static_cast<int>(mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        if(get_thread_global_1d_id() == 0)
+            printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
+
+        // using ThreadBufferLengths_1_1 = Sequence<1, 1>;
+        // constexpr auto thread_buffer_desc_1_1 =
+        //     make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
+        // constexpr auto grid_desc_1_1 =
+        //     make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
+
+        // const auto mean_grid = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        //     p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, 1, true> mean_thread;
+
+        // float mean1 = (mean_grid.template Get<MeanDataType>(0, true));
+        // if(get_thread_global_1d_id() == 0)
+        //     printf("global mean = %f\n", mean1);
+
+        // auto threadwise_mean_load_m_k =
+        //     ThreadwiseTensorSliceTransfer_v2<MeanDataType,
+        //                                      ComputeDataType,
+        //                                      decltype(mean_var_count_grid_desc_m_n),
+        //                                      decltype(thread_buffer_desc_1_1),
+        //                                      ThreadBufferLengths_1_1,
+        //                                      Sequence<0, 1>,
+        //                                      1,
+        //                                      1,
+        //                                      1,
+        //                                      true>(mean_var_count_grid_desc_m_n,
+        //                                            make_multi_index(0, 0));
+
+        // threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
+        //                              mean_grid,
+        //                              thread_buffer_desc_1_1,
+        //                              make_tuple(I0, I0),
+        //                              mean_thread);
+
+        // if(get_thread_global_1d_id() == 0)
+        //     printf("threadwise mean = %f\n", mean_thread(Number<0>{}));
+
+        // // Thread/Block id
+        // const index_t thread_local_id = get_thread_local_1d_id();
+        // const index_t block_global_id = get_block_1d_id();
+
+        // const auto thread_cluster_idx =
+        //     thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+
+        // const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        // const auto thread_n_cluster_id = thread_cluster_idx[I1];
+
+        // // step1: Merge mean and variance
+        // using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
+        // constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
+        //     make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
+
+        // auto threadwise_mean_load_m_k =
+        //     ThreadwiseTensorSliceTransfer_v2<MeanDataType,
+        //                                      ComputeDataType,
+        //                                      MeanVarCountGridDesc_M_N,
+        //                                      decltype(thread_buffer_desc_m_1),
+        //                                      ThreadBufferLengths_M_1,
+        //                                      Sequence<0, 1>,
+        //                                      1,
+        //                                      1,
+        //                                      1,
+        //                                      true>(mean_var_count_grid_desc_m_n,
+        //                                            make_multi_index(0, 0));
+
+        // auto threadwise_var_load_m_k =
+        //     ThreadwiseTensorSliceTransfer_v2<VarDataType,
+        //                                      ComputeDataType,
+        //                                      MeanVarCountGridDesc_M_N,
+        //                                      decltype(thread_buffer_desc_m_1),
+        //                                      ThreadBufferLengths_M_1,
+        //                                      Sequence<0, 1>,
+        //                                      1,
+        //                                      1,
+        //                                      1,
+        //                                      true>(
+        //         mean_var_count_grid_desc_m_n,
+        //         make_multi_index(block_global_id * M_BlockTileSize +
+        //                              thread_m_cluster_id * MThreadSliceSize,
+        //                          thread_n_cluster_id));
+
+        // auto threadwise_count_load_m_k =
+        //     ThreadwiseTensorSliceTransfer_v2<int32_t,
+        //                                      int32_t,
+        //                                      MeanVarCountGridDesc_M_N,
+        //                                      decltype(thread_buffer_desc_m_1),
+        //                                      ThreadBufferLengths_M_1,
+        //                                      Sequence<0, 1>,
+        //                                      1,
+        //                                      1,
+        //                                      1,
+        //                                      true>(
+        //         mean_var_count_grid_desc_m_n,
+        //         make_multi_index(block_global_id * M_BlockTileSize +
+        //                              thread_m_cluster_id * MThreadSliceSize,
+        //                          thread_n_cluster_id));
+
+        // const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        //     p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        // const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        //     p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        // const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        //     p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+        //     in_welford_mean_thread_buf;
+        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+        //     in_welford_var_thread_buf;
+        // StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+        //     in_welford_count_thread_buf;
+
+        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+        //     welford_mean_thread_buf;
+        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+        //     welford_var_thread_buf;
+        // StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+        //     welford_count_thread_buf;
+
+        // constexpr auto mean_var_count_thread_copy_step_m_n =
+        //     make_multi_index(0, NThreadClusterSize);
+
+        // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+        //     welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
+        //     welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
+        //     welford_count_thread_buf(I) = 0;
+        // });
+
+        // for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
+        //     ++reducedTiles)
+        // {
+        //     threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
+        //                                  welford_mean_global_val_buf,
+        //                                  thread_buffer_desc_m_1,
+        //                                  make_tuple(I0, I0),
+        //                                  in_welford_mean_thread_buf);
+
+        //     // threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
+        //     //                             welford_var_global_val_buf,
+        //     //                             thread_buffer_desc_m_1,
+        //     //                             make_tuple(I0, I0),
+        //     //                             in_welford_var_thread_buf);
+
+        //     // threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
+        //     //                               welford_count_global_val_buf,
+        //     //                               thread_buffer_desc_m_1,
+        //     //                               make_tuple(I0, I0),
+        //     //                               in_welford_count_thread_buf);
+
+        //     // ThreadwiseWelford::Run(in_welford_mean_thread_buf,
+        //     //                        in_welford_var_thread_buf,
+        //     //                        in_welford_count_thread_buf,
+        //     //                        welford_mean_thread_buf,
+        //     //                        welford_var_thread_buf,
+        //     //                        welford_count_thread_buf);
+
+        //     // threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+        //     //                                             mean_var_count_thread_copy_step_m_n);
+        //     // threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+        //     //                                            mean_var_count_thread_copy_step_m_n);
+        //     // threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+        //     //                                              mean_var_count_thread_copy_step_m_n);
+
+        //     static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+        //         if(get_thread_global_1d_id() == 0)
+        //             printf("mean = %f, var = %f, count = %d\n",
+        //                    in_welford_mean_thread_buf(I),
+        //                    in_welford_var_thread_buf(I),
+        //                    in_welford_count_thread_buf(I));
+        //     });
+        // }
+
+        // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+        //     if constexpr(I > 0)
+        //         block_sync_lds();
+
+        //     if(get_thread_global_1d_id() == 0)
+        //         printf("count = %d\n", welford_count_thread_buf(I));
+
+        //     BlockwiseWelford::Run(
+        //         welford_mean_thread_buf(I), welford_var_thread_buf(I),
+        //         welford_count_thread_buf(I));
+        // });
     } // run
 };
...
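Most of the second-half kernel body is still commented out in this commit, but the commented code and the first-half changes indicate the intended computation: per row of E, merge the NBlock (mean, var, count) partials from the workspace, then normalize. A scalar sketch of that target computation, with illustrative names rather than the CK threadwise-transfer machinery:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Hedged sketch of what GridwiseWelfordSecondHalfLayernorm2d is meant to
// compute once finished: merge the per-N-tile Welford partials for a row,
// then apply layernorm to E to produce H.
struct Partial { float mean, var; int count; };

static Partial merge(Partial a, Partial b)
{
    Partial o;
    o.count = a.count + b.count;
    const float delta = b.mean - a.mean;
    o.mean = a.mean + delta * b.count / o.count;
    // combine variances via the pairwise (Chan et al.) formula
    o.var = (a.var * a.count + b.var * b.count +
             delta * delta * a.count * b.count / o.count) / o.count;
    return o;
}

int main()
{
    const int N = 8;
    std::vector<float> e = {1, 2, 3, 4, 5, 6, 7, 8}; // one GEMM output row
    std::vector<float> gamma(N, 1.f), beta(N, 0.f);
    const float epsilon = 1e-5f;

    // Pretend the first kernel produced two N-block partials for this row.
    Partial p0 = {2.5f, 1.25f, 4}; // stats of e[0..3]
    Partial p1 = {6.5f, 1.25f, 4}; // stats of e[4..7]
    const Partial row = merge(p0, p1); // mean = 4.5, var = 5.25, count = 8

    // h = (e - mean) / sqrt(var + eps) * gamma + beta
    const float inv_std = 1.f / std::sqrt(row.var + epsilon);
    for(int n = 0; n < N; ++n)
    {
        const float h = (e[n] - row.mean) * inv_std * gamma[n] + beta[n];
        std::printf("h[%d] = %f\n", n, h);
    }
    return 0;
}
```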