Commit 6916e3e4 authored by rocking

Clean the code

parent ad2f82ac
@@ -125,11 +125,11 @@ __global__ void
     const GammaDataType* __restrict__ p_gamma_grid,
     const BetaDataType* __restrict__ p_beta_grid,
     HDataType* __restrict__ p_h_grid,
-    const EHGridDesc_M_N& e_grid_desc_m_n,
-    const EHGridDesc_M_N& h_grid_desc_m_n,
-    const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
-    const GammaBetaGridDesc_N& gamma_grid_desc_n,
-    const GammaBetaGridDesc_N& beta_grid_desc_n,
+    const EHGridDesc_M_N e_grid_desc_m_n,
+    const EHGridDesc_M_N h_grid_desc_m_n,
+    const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n,
+    const GammaBetaGridDesc_N gamma_grid_desc_n,
+    const GammaBetaGridDesc_N beta_grid_desc_n,
     index_t blkgroup_size,
     index_t num_mean_var_count_k_block_tile_iteration,
     index_t num_xy_k_block_tile_iteration,
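Note on this hunk: `__global__` kernel parameters are copied into the kernel argument space at launch, so taking the grid descriptors by value is the safe form; a `const&` parameter would leave the device dereferencing a host-side address. A minimal sketch of the distinction, using toy types that are not part of this commit:

    #include <hip/hip_runtime.h>

    struct ToyDesc
    {
        int m, n; // stand-in for a real grid descriptor
    };

    // Safe: 'd' is copied into kernarg space, so the device reads its own copy.
    __global__ void by_value(ToyDesc d, int* out) { out[0] = d.m * d.n; }

    // Broken: 'd' would hold the address of a host object, which the GPU
    // cannot legally dereference.
    // __global__ void by_ref(const ToyDesc& d, int* out) { out[0] = d.m * d.n; }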
@@ -507,9 +507,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         mean_var_count_grid_desc_m_n_ =
             DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
 
-        int s = mean_var_count_grid_desc_m_n_.GetElementSpaceSize();
-        printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
-
         hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
 
         int gemm_welford_size = MRaw * gemm_nblock_;
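For context on the allocations in this hunk: the first GEMM+Welford kernel emits one partial (mean, var, count) per output row for each of the `gemm_nblock_` block groups along N, so each intermediate Welford buffer needs `MRaw * gemm_nblock_` elements. A rough sketch of the sizing with made-up numbers; the `p_welford_*_grid_` member names are hypothetical, mirroring `p_e_grid_` above:

    // e.g. MRaw = 1024 output rows, gemm_nblock_ = 8 block groups along N
    int gemm_welford_size = MRaw * gemm_nblock_; // 8192 partials per buffer

    hip_check_error(hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
    hip_check_error(hipMalloc(&p_welford_var_grid_, sizeof(VarDataType) * gemm_welford_size));
    hip_check_error(hipMalloc(&p_welford_count_grid_, sizeof(int32_t) * gemm_welford_size));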
......
@@ -1080,12 +1080,6 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                 count_thread_buf,
                 mean_var_count_grid_desc_mblock_mperblock_nblock,
                 welford_count_grid_buf);
-
-            float mean = static_cast<float>(mean_thread_buf(I0));
-            float var  = static_cast<float>(var_thread_buf(I0));
-            int count  = count_thread_buf(I0);
-            if(i == 0 && get_thread_global_1d_id() == 0)
-                printf("1st kernel mean = %f, var = %f, count = %d\n", mean, var, count);
         });
 
     } // shuffle C + Ds + welford + write out
......
@@ -63,8 +63,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});
 
-    using ThreadReduceSrcDesc_M_1 = decltype(
-        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
+
+    static constexpr auto thread_buffer_desc_m_1 =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
+
+    using ThreadReduceSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
 
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
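The refactor in this hunk names the descriptor as a `constexpr` value first and derives the alias from it with `decltype`, so the threadwise transfers added in the next hunk can pass `thread_buffer_desc_m_1` directly instead of re-spelling the `make_naive_tensor_descriptor_packed(...)` expression. A generic sketch of the idiom, with toy code that is not CK:

    constexpr auto make_desc() { return 42; } // stand-in for a descriptor factory

    constexpr auto desc = make_desc(); // the value, usable at run time
    using DescType = decltype(desc);   // the type, usable as a template argument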
@@ -124,201 +128,144 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         ignore = numEBlockTileIteration_N;
         ignore = epsilon;
 
-        // float mean = static_cast<float>(p_in_welford_mean_grid[0]);
-        // float var  = static_cast<float>(p_in_welford_var_grid[0]);
-        // int count  = p_in_welford_count_grid[0];
-        // if(get_thread_global_1d_id() == 0)
-        //     printf("kernel mean = %f, var = %f, count = %d\n", mean, var, count);
-
-        float mean = static_cast<float>(p_in_welford_mean_grid[0]);
-        if(get_thread_global_1d_id() == 0)
-            printf("mean = %f\n", mean);
-
-        int s = static_cast<int>(mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-        if(get_thread_global_1d_id() == 0)
-            printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
-
-        // using ThreadBufferLengths_1_1 = Sequence<1, 1>;
-        // constexpr auto thread_buffer_desc_1_1 =
-        //     make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
-        // constexpr auto grid_desc_1_1 =
-        //     make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
-
-        // const auto mean_grid = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, 1, true> mean_thread;
-
-        // float mean1 = (mean_grid.template Get<MeanDataType>(0, true));
-        // if(get_thread_global_1d_id() == 0)
-        //     printf("global mean = %f\n", mean1);
-
-        // auto threadwise_mean_load_m_k =
-        //     ThreadwiseTensorSliceTransfer_v2<MeanDataType,
-        //                                      ComputeDataType,
-        //                                      decltype(mean_var_count_grid_desc_m_n),
-        //                                      decltype(thread_buffer_desc_1_1),
-        //                                      ThreadBufferLengths_1_1,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      true>(mean_var_count_grid_desc_m_n,
-        //                                            make_multi_index(0, 0));
-
-        // threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
-        //                              mean_grid,
-        //                              thread_buffer_desc_1_1,
-        //                              make_tuple(I0, I0),
-        //                              mean_thread);
-
-        // if(get_thread_global_1d_id() == 0)
-        //     printf("threadwise mean = %f\n", mean_thread(Number<0>{}));
-
-        // // Thread/Block id
-        // const index_t thread_local_id = get_thread_local_1d_id();
-        // const index_t block_global_id = get_block_1d_id();
-        // const auto thread_cluster_idx =
-        //     thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
-        // const auto thread_m_cluster_id = thread_cluster_idx[I0];
-        // const auto thread_n_cluster_id = thread_cluster_idx[I1];
-
-        // // step1: Merge mean and variance
-        // using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
-        // constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
-        //     make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
-        // auto threadwise_mean_load_m_k =
-        //     ThreadwiseTensorSliceTransfer_v2<MeanDataType,
-        //                                      ComputeDataType,
-        //                                      MeanVarCountGridDesc_M_N,
-        //                                      decltype(thread_buffer_desc_m_1),
-        //                                      ThreadBufferLengths_M_1,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      true>(mean_var_count_grid_desc_m_n,
-        //                                            make_multi_index(0, 0));
-
-        // auto threadwise_var_load_m_k =
-        //     ThreadwiseTensorSliceTransfer_v2<VarDataType,
-        //                                      ComputeDataType,
-        //                                      MeanVarCountGridDesc_M_N,
-        //                                      decltype(thread_buffer_desc_m_1),
-        //                                      ThreadBufferLengths_M_1,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      true>(
-        //     mean_var_count_grid_desc_m_n,
-        //     make_multi_index(block_global_id * M_BlockTileSize +
-        //                          thread_m_cluster_id * MThreadSliceSize,
-        //                      thread_n_cluster_id));
-
-        // auto threadwise_count_load_m_k =
-        //     ThreadwiseTensorSliceTransfer_v2<int32_t,
-        //                                      int32_t,
-        //                                      MeanVarCountGridDesc_M_N,
-        //                                      decltype(thread_buffer_desc_m_1),
-        //                                      ThreadBufferLengths_M_1,
-        //                                      Sequence<0, 1>,
-        //                                      1,
-        //                                      1,
-        //                                      1,
-        //                                      true>(
-        //     mean_var_count_grid_desc_m_n,
-        //     make_multi_index(block_global_id * M_BlockTileSize +
-        //                          thread_m_cluster_id * MThreadSliceSize,
-        //                      thread_n_cluster_id));
-
-        // const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        // const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        // const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-        //     p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-        //     in_welford_mean_thread_buf;
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-        //     in_welford_var_thread_buf;
-        // StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
-        //     in_welford_count_thread_buf;
-
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-        //     welford_mean_thread_buf;
-        // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-        //     welford_var_thread_buf;
-        // StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
-        //     welford_count_thread_buf;
-
-        // constexpr auto mean_var_count_thread_copy_step_m_n =
-        //     make_multi_index(0, NThreadClusterSize);
-
-        // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-        //     welford_mean_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
-        //     welford_var_thread_buf(I)   = type_convert<ComputeDataType>(0.0f);
-        //     welford_count_thread_buf(I) = 0;
-        // });
-
-        // for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
-        //     ++reducedTiles)
-        // {
-        //     threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
-        //                                  welford_mean_global_val_buf,
-        //                                  thread_buffer_desc_m_1,
-        //                                  make_tuple(I0, I0),
-        //                                  in_welford_mean_thread_buf);
-
-        //     // threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
-        //     //                             welford_var_global_val_buf,
-        //     //                             thread_buffer_desc_m_1,
-        //     //                             make_tuple(I0, I0),
-        //     //                             in_welford_var_thread_buf);
-
-        //     // threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
-        //     //                               welford_count_global_val_buf,
-        //     //                               thread_buffer_desc_m_1,
-        //     //                               make_tuple(I0, I0),
-        //     //                               in_welford_count_thread_buf);
-
-        //     // ThreadwiseWelford::Run(in_welford_mean_thread_buf,
-        //     //                        in_welford_var_thread_buf,
-        //     //                        in_welford_count_thread_buf,
-        //     //                        welford_mean_thread_buf,
-        //     //                        welford_var_thread_buf,
-        //     //                        welford_count_thread_buf);
-
-        //     // threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-        //     //                                             mean_var_count_thread_copy_step_m_n);
-        //     // threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-        //     //                                            mean_var_count_thread_copy_step_m_n);
-        //     // threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-        //     //                                              mean_var_count_thread_copy_step_m_n);
-
-        //     static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-        //         if(get_thread_global_1d_id() == 0)
-        //             printf("mean = %f, var = %f, count = %d\n",
-        //                    in_welford_mean_thread_buf(I),
-        //                    in_welford_var_thread_buf(I),
-        //                    in_welford_count_thread_buf(I));
-        //     });
-        // }
-
-        // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-        //     if constexpr(I > 0)
-        //         block_sync_lds();
-
-        //     if(get_thread_global_1d_id() == 0)
-        //         printf("count = %d\n", welford_count_thread_buf(I));
-
-        //     BlockwiseWelford::Run(
-        //         welford_mean_thread_buf(I), welford_var_thread_buf(I),
-        //         welford_count_thread_buf(I));
-        // });
+        // Thread/Block id
+        const index_t thread_local_id = get_thread_local_1d_id();
+        const index_t block_global_id = get_block_1d_id();
+        const auto thread_cluster_idx =
+            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_n_cluster_id = thread_cluster_idx[I1];
+
+        // step1: Merge mean and variance
+        auto threadwise_mean_load_m_k =
+            ThreadwiseTensorSliceTransfer_v2<MeanDataType,
+                                             ComputeDataType,
+                                             MeanVarCountGridDesc_M_N,
+                                             decltype(thread_buffer_desc_m_1),
+                                             ThreadBufferLengths_M_1,
+                                             Sequence<0, 1>,
+                                             1,
+                                             1,
+                                             1,
+                                             true>(
+                mean_var_count_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id));
+
+        auto threadwise_var_load_m_k =
+            ThreadwiseTensorSliceTransfer_v2<VarDataType,
+                                             ComputeDataType,
+                                             MeanVarCountGridDesc_M_N,
+                                             decltype(thread_buffer_desc_m_1),
+                                             ThreadBufferLengths_M_1,
+                                             Sequence<0, 1>,
+                                             1,
+                                             1,
+                                             1,
+                                             true>(
+                mean_var_count_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id));
+
+        auto threadwise_count_load_m_k =
+            ThreadwiseTensorSliceTransfer_v2<int32_t,
+                                             int32_t,
+                                             MeanVarCountGridDesc_M_N,
+                                             decltype(thread_buffer_desc_m_1),
+                                             ThreadBufferLengths_M_1,
+                                             Sequence<0, 1>,
+                                             1,
+                                             1,
+                                             1,
+                                             true>(
+                mean_var_count_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id));
+
+        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            in_welford_mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            in_welford_var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+            in_welford_count_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            welford_mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            welford_var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+            welford_count_thread_buf;
+
+        constexpr auto mean_var_count_thread_copy_step_m_n =
+            make_multi_index(0, NThreadClusterSize);
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            welford_mean_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
+            welford_var_thread_buf(I)   = type_convert<ComputeDataType>(0.0f);
+            welford_count_thread_buf(I) = 0;
+        });
+
+        for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
+            ++reducedTiles)
+        {
+            threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
+                                         welford_mean_global_val_buf,
+                                         thread_buffer_desc_m_1,
+                                         make_tuple(I0, I0),
+                                         in_welford_mean_thread_buf);
+
+            threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
+                                        welford_var_global_val_buf,
+                                        thread_buffer_desc_m_1,
+                                        make_tuple(I0, I0),
+                                        in_welford_var_thread_buf);
+
+            threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
+                                          welford_count_global_val_buf,
+                                          thread_buffer_desc_m_1,
+                                          make_tuple(I0, I0),
+                                          in_welford_count_thread_buf);
+
+            ThreadwiseWelford::Run(in_welford_mean_thread_buf,
+                                   in_welford_var_thread_buf,
+                                   in_welford_count_thread_buf,
+                                   welford_mean_thread_buf,
+                                   welford_var_thread_buf,
+                                   welford_count_thread_buf);
+
+            threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                        mean_var_count_thread_copy_step_m_n);
+            threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                       mean_var_count_thread_copy_step_m_n);
+            threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                         mean_var_count_thread_copy_step_m_n);
+        }
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            if constexpr(I > 0)
+                block_sync_lds();
+
+            BlockwiseWelford::Run(
+                welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
+        });
+
+        // step2: normalization
+        for(index_t reducedTiles = 0; reducedTiles < numEBlockTileIteration_N; ++reducedTiles)
+        {
+            // TODO
+        }
 
     } // run
 };
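For reference, the `ThreadwiseWelford::Run` / `BlockwiseWelford::Run` calls in this hunk merge the partial (mean, variance, count) triples written by the first-half kernel, and the `// TODO` in step2 is where the layernorm output h = gamma * (e - mean) / sqrt(var + epsilon) + beta would be produced. A scalar sketch of the standard pairwise combine rule (Chan et al.) that such a merge implements, assuming `var` holds the biased running variance; this is an illustration, not the CK implementation:

    #include <cmath>
    #include <cstdio>

    struct Welford
    {
        float mean = 0.f;
        float var  = 0.f; // biased running variance
        int count  = 0;
    };

    // Merge partial aggregate b into a.
    void welford_merge(Welford& a, const Welford& b)
    {
        const int n = a.count + b.count;
        if(n == 0)
            return;
        const float delta = b.mean - a.mean;
        // Recover sums of squared deviations (M2) from the stored variances,
        // combine them, then convert back to a variance.
        const float m2 = a.var * a.count + b.var * b.count +
                         delta * delta * a.count * b.count / n;
        a.mean  = a.mean + delta * b.count / n;
        a.var   = m2 / n;
        a.count = n;
    }

    // One normalized value, as step2 would compute it per element.
    float layernorm(float e, float mean, float var, float eps, float gamma, float beta)
    {
        return gamma * (e - mean) / std::sqrt(var + eps) + beta;
    }

    int main()
    {
        Welford a{1.0f, 0.25f, 4}, b{3.0f, 1.0f, 4};
        welford_merge(a, b); // -> mean = 2, var = 1.625, count = 8
        std::printf("mean = %f, var = %f, count = %d\n", a.mean, a.var, a.count);
        std::printf("h = %f\n", layernorm(2.5f, a.mean, a.var, 1e-5f, 1.0f, 0.0f));
    }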
......