Commit 6916e3e4 authored by rocking

Clean the code

parent ad2f82ac
@@ -125,11 +125,11 @@ __global__ void
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
HDataType* __restrict__ p_h_grid,
- const EHGridDesc_M_N& e_grid_desc_m_n,
- const EHGridDesc_M_N& h_grid_desc_m_n,
- const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
- const GammaBetaGridDesc_N& gamma_grid_desc_n,
- const GammaBetaGridDesc_N& beta_grid_desc_n,
+ const EHGridDesc_M_N e_grid_desc_m_n,
+ const EHGridDesc_M_N h_grid_desc_m_n,
+ const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n,
+ const GammaBetaGridDesc_N gamma_grid_desc_n,
+ const GammaBetaGridDesc_N beta_grid_desc_n,
index_t blkgroup_size,
index_t num_mean_var_count_k_block_tile_iteration,
index_t num_xy_k_block_tile_iteration,
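The hunk above switches the kernel's descriptor parameters from const references to plain values. HIP kernel arguments are copied to the device at launch, so a small trivially-copyable descriptor is safe to pass by value, whereas a reference parameter would leave device code dereferencing a host-side address. A minimal sketch of the by-value convention, using a hypothetical Desc2D stand-in rather than CK's real grid descriptors:

#include <hip/hip_runtime.h>
#include <cstdio>

// Hypothetical stand-in for a CK grid descriptor: small and trivially copyable.
struct Desc2D
{
    int lengths[2];
    int strides[2];
};

// desc is passed by value: the launch copies it into the kernel argument
// buffer, so the device reads it directly instead of chasing a host pointer.
__global__ void kernel_by_value(Desc2D desc, float* p_out)
{
    p_out[0] = static_cast<float>(desc.lengths[0] * desc.lengths[1]);
}

int main()
{
    Desc2D desc{{4, 8}, {8, 1}};
    float* p_out = nullptr;
    (void)hipMalloc(&p_out, sizeof(float));
    hipLaunchKernelGGL(kernel_by_value, dim3(1), dim3(1), 0, 0, desc, p_out);
    float result = 0.f;
    (void)hipMemcpy(&result, p_out, sizeof(float), hipMemcpyDeviceToHost);
    printf("%f\n", result); // expected 32.000000
    (void)hipFree(p_out);
    return 0;
}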
@@ -507,9 +507,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_m_n_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
- int s = mean_var_count_grid_desc_m_n_.GetElementSpaceSize();
- printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
int gemm_welford_size = MRaw * gemm_nblock_;
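For scale: gemm_welford_size counts one partial result per (row, N-block) pair. With a hypothetical MRaw = 1024 and gemm_nblock_ = 8, the first kernel writes 1024 × 8 = 8192 partial mean/variance/count triples, which the second kernel then merges along the N-block dimension.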
@@ -1080,12 +1080,6 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
count_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
- float mean = static_cast<float>(mean_thread_buf(I0));
- float var = static_cast<float>(var_thread_buf(I0));
- int count = count_thread_buf(I0);
- if(i == 0 && get_thread_global_1d_id() == 0)
- printf("1st kernel mean = %f, var = %f, count = %d\n", mean, var, count);
});
} // shuffle C + Ds + welford + write out
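The comment above marks where the first kernel finishes: alongside the C-shuffle epilogue it writes out per-tile running mean, variance, and count rather than final statistics. For reference, a minimal scalar sketch of the single-pass Welford update this kind of accumulation is built on (illustrative names, not CK's ThreadwiseWelford interface):

#include <cstdio>

struct WelfordState
{
    float mean = 0.f;
    float m2   = 0.f; // running sum of squared deviations
    int count  = 0;
};

// One Welford update: fold sample x into the running statistics.
void welford_push(WelfordState& s, float x)
{
    ++s.count;
    float delta = x - s.mean;
    s.mean += delta / s.count;
    s.m2 += delta * (x - s.mean);
}

int main()
{
    WelfordState s;
    for(float x : {1.f, 2.f, 3.f, 4.f})
        welford_push(s, x);
    // Population variance is m2 / count.
    printf("mean = %f, var = %f, count = %d\n", s.mean, s.m2 / s.count, s.count);
    return 0; // prints mean = 2.500000, var = 1.250000, count = 4
}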
@@ -63,8 +63,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});
- using ThreadReduceSrcDesc_M_1 = decltype(
- make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+ using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
+ static constexpr auto thread_buffer_desc_m_1 =
+ make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
+ using ThreadReduceSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
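The added lines name the packed M-by-1 descriptor as a constexpr object and recover its type with decltype, so the maker expression is written once and the object itself can also be passed around at runtime. A standalone sketch of that pattern, with a hypothetical PackedDesc in place of CK's tensor-descriptor machinery:

#include <cstdio>

// Hypothetical stand-in for CK's make_naive_tensor_descriptor_packed.
template <int M, int N>
struct PackedDesc
{
    static constexpr int element_space_size = M * N;
};

template <int M, int N>
constexpr PackedDesc<M, N> make_packed_desc()
{
    return {};
}

// Name the constexpr object once...
static constexpr auto thread_buffer_desc = make_packed_desc<4, 1>();
// ...and derive the alias from it instead of repeating the maker expression.
using ThreadBufferDesc = decltype(thread_buffer_desc);

int main()
{
    printf("%d\n", ThreadBufferDesc::element_space_size); // prints 4
    return 0;
}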
@@ -124,201 +128,144 @@ struct GridwiseWelfordSecondHalfLayernorm2d
- ignore = numEBlockTileIteration_N;
ignore = epsilon;
- // float mean = static_cast<float>(p_in_welford_mean_grid[0]);
- // float var = static_cast<float>(p_in_welford_var_grid[0]);
- // int count = p_in_welford_count_grid[0];
- // if(get_thread_global_1d_id() == 0)
- // printf("kernel mean = %f, var = %f, count = %d\n", mean, var, count);
- float mean = static_cast<float>(p_in_welford_mean_grid[0]);
- if(get_thread_global_1d_id() == 0)
- printf("mean = %f\n", mean);
- int s = static_cast<int>(mean_var_count_grid_desc_m_n.GetElementSpaceSize());
- if(get_thread_global_1d_id() == 0)
- printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
- // using ThreadBufferLengths_1_1 = Sequence<1, 1>;
- // constexpr auto thread_buffer_desc_1_1 =
- // make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
- // constexpr auto grid_desc_1_1 =
- // make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
- // const auto mean_grid = make_dynamic_buffer<AddressSpaceEnum::Global>(
- // p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
- // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, 1, true> mean_thread;
- // float mean1 = (mean_grid.template Get<MeanDataType>(0, true));
- // if(get_thread_global_1d_id() == 0)
- // printf("global mean = %f\n", mean1);
- // auto threadwise_mean_load_m_k =
- // ThreadwiseTensorSliceTransfer_v2<MeanDataType,
- // ComputeDataType,
- // decltype(mean_var_count_grid_desc_m_n),
- // decltype(thread_buffer_desc_1_1),
- // ThreadBufferLengths_1_1,
- // Sequence<0, 1>,
- // 1,
- // 1,
- // 1,
- // true>(mean_var_count_grid_desc_m_n,
- // make_multi_index(0, 0));
- // threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
- // mean_grid,
- // thread_buffer_desc_1_1,
- // make_tuple(I0, I0),
- // mean_thread);
- // if(get_thread_global_1d_id() == 0)
- // printf("threadwise mean = %f\n", mean_thread(Number<0>{}));
- // // Thread/Block id
- // const index_t thread_local_id = get_thread_local_1d_id();
- // const index_t block_global_id = get_block_1d_id();
- // const auto thread_cluster_idx =
- // thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
- // const auto thread_m_cluster_id = thread_cluster_idx[I0];
- // const auto thread_n_cluster_id = thread_cluster_idx[I1];
- // // step1: Merge mean and variance
- // using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
- // constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
- // make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
- // auto threadwise_mean_load_m_k =
- // ThreadwiseTensorSliceTransfer_v2<MeanDataType,
- // ComputeDataType,
- // MeanVarCountGridDesc_M_N,
- // decltype(thread_buffer_desc_m_1),
- // ThreadBufferLengths_M_1,
- // Sequence<0, 1>,
- // 1,
- // 1,
- // 1,
- // true>(mean_var_count_grid_desc_m_n,
- // make_multi_index(0, 0));
- // auto threadwise_var_load_m_k =
- // ThreadwiseTensorSliceTransfer_v2<VarDataType,
- // ComputeDataType,
- // MeanVarCountGridDesc_M_N,
- // decltype(thread_buffer_desc_m_1),
- // ThreadBufferLengths_M_1,
- // Sequence<0, 1>,
- // 1,
- // 1,
- // 1,
- // true>(
- // mean_var_count_grid_desc_m_n,
- // make_multi_index(block_global_id * M_BlockTileSize +
- // thread_m_cluster_id * MThreadSliceSize,
- // thread_n_cluster_id));
- // auto threadwise_count_load_m_k =
- // ThreadwiseTensorSliceTransfer_v2<int32_t,
- // int32_t,
- // MeanVarCountGridDesc_M_N,
- // decltype(thread_buffer_desc_m_1),
- // ThreadBufferLengths_M_1,
- // Sequence<0, 1>,
- // 1,
- // 1,
- // 1,
- // true>(
- // mean_var_count_grid_desc_m_n,
- // make_multi_index(block_global_id * M_BlockTileSize +
- // thread_m_cluster_id * MThreadSliceSize,
- // thread_n_cluster_id));
- // const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
- // p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
- // const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
- // p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
- // const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
- // p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
- // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
- // in_welford_mean_thread_buf;
- // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
- // in_welford_var_thread_buf;
- // StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
- // in_welford_count_thread_buf;
- // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
- // welford_mean_thread_buf;
- // StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
- // welford_var_thread_buf;
- // StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
- // welford_count_thread_buf;
- // constexpr auto mean_var_count_thread_copy_step_m_n =
- // make_multi_index(0, NThreadClusterSize);
- // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
- // welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
- // welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
- // welford_count_thread_buf(I) = 0;
- // });
- // for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
- // ++reducedTiles)
- // {
- // threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
- // welford_mean_global_val_buf,
- // thread_buffer_desc_m_1,
- // make_tuple(I0, I0),
- // in_welford_mean_thread_buf);
- // // threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
- // // welford_var_global_val_buf,
- // // thread_buffer_desc_m_1,
- // // make_tuple(I0, I0),
- // // in_welford_var_thread_buf);
- // // threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
- // // welford_count_global_val_buf,
- // // thread_buffer_desc_m_1,
- // // make_tuple(I0, I0),
- // // in_welford_count_thread_buf);
- // // ThreadwiseWelford::Run(in_welford_mean_thread_buf,
- // // in_welford_var_thread_buf,
- // // in_welford_count_thread_buf,
- // // welford_mean_thread_buf,
- // // welford_var_thread_buf,
- // // welford_count_thread_buf);
- // // threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
- // // mean_var_count_thread_copy_step_m_n);
- // // threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
- // // mean_var_count_thread_copy_step_m_n);
- // // threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
- // // mean_var_count_thread_copy_step_m_n);
- // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
- // if(get_thread_global_1d_id() == 0)
- // printf("mean = %f, var = %f, count = %d\n",
- // in_welford_mean_thread_buf(I),
- // in_welford_var_thread_buf(I),
- // in_welford_count_thread_buf(I));
- // });
- // }
- // static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
- // if constexpr(I > 0)
- // block_sync_lds();
- // if(get_thread_global_1d_id() == 0)
- // printf("count = %d\n", welford_count_thread_buf(I));
- // BlockwiseWelford::Run(
- // welford_mean_thread_buf(I), welford_var_thread_buf(I),
- // welford_count_thread_buf(I));
- // });
+ // Thread/Block id
+ const index_t thread_local_id = get_thread_local_1d_id();
+ const index_t block_global_id = get_block_1d_id();
+ const auto thread_cluster_idx =
+ thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+ const auto thread_m_cluster_id = thread_cluster_idx[I0];
+ const auto thread_n_cluster_id = thread_cluster_idx[I1];
+ // step1: Merge mean and variance
+ auto threadwise_mean_load_m_k =
+ ThreadwiseTensorSliceTransfer_v2<MeanDataType,
+ ComputeDataType,
+ MeanVarCountGridDesc_M_N,
+ decltype(thread_buffer_desc_m_1),
+ ThreadBufferLengths_M_1,
+ Sequence<0, 1>,
+ 1,
+ 1,
+ 1,
+ true>(
+ mean_var_count_grid_desc_m_n,
+ make_multi_index(block_global_id * M_BlockTileSize +
+ thread_m_cluster_id * MThreadSliceSize,
+ thread_n_cluster_id));
+ auto threadwise_var_load_m_k =
+ ThreadwiseTensorSliceTransfer_v2<VarDataType,
+ ComputeDataType,
+ MeanVarCountGridDesc_M_N,
+ decltype(thread_buffer_desc_m_1),
+ ThreadBufferLengths_M_1,
+ Sequence<0, 1>,
+ 1,
+ 1,
+ 1,
+ true>(
+ mean_var_count_grid_desc_m_n,
+ make_multi_index(block_global_id * M_BlockTileSize +
+ thread_m_cluster_id * MThreadSliceSize,
+ thread_n_cluster_id));
+ auto threadwise_count_load_m_k =
+ ThreadwiseTensorSliceTransfer_v2<int32_t,
+ int32_t,
+ MeanVarCountGridDesc_M_N,
+ decltype(thread_buffer_desc_m_1),
+ ThreadBufferLengths_M_1,
+ Sequence<0, 1>,
+ 1,
+ 1,
+ 1,
+ true>(
+ mean_var_count_grid_desc_m_n,
+ make_multi_index(block_global_id * M_BlockTileSize +
+ thread_m_cluster_id * MThreadSliceSize,
+ thread_n_cluster_id));
+ const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+ p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+ const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+ p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+ const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+ p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+ StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+ in_welford_mean_thread_buf;
+ StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+ in_welford_var_thread_buf;
+ StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+ in_welford_count_thread_buf;
+ StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+ welford_mean_thread_buf;
+ StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+ welford_var_thread_buf;
+ StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+ welford_count_thread_buf;
+ constexpr auto mean_var_count_thread_copy_step_m_n =
+ make_multi_index(0, NThreadClusterSize);
+ static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+ welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
+ welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
+ welford_count_thread_buf(I) = 0;
+ });
+ for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
+ ++reducedTiles)
+ {
+ threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
+ welford_mean_global_val_buf,
+ thread_buffer_desc_m_1,
+ make_tuple(I0, I0),
+ in_welford_mean_thread_buf);
+ threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
+ welford_var_global_val_buf,
+ thread_buffer_desc_m_1,
+ make_tuple(I0, I0),
+ in_welford_var_thread_buf);
+ threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
+ welford_count_global_val_buf,
+ thread_buffer_desc_m_1,
+ make_tuple(I0, I0),
+ in_welford_count_thread_buf);
+ ThreadwiseWelford::Run(in_welford_mean_thread_buf,
+ in_welford_var_thread_buf,
+ in_welford_count_thread_buf,
+ welford_mean_thread_buf,
+ welford_var_thread_buf,
+ welford_count_thread_buf);
+ threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+ mean_var_count_thread_copy_step_m_n);
+ threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+ mean_var_count_thread_copy_step_m_n);
+ threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+ mean_var_count_thread_copy_step_m_n);
+ }
+ static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+ if constexpr(I > 0)
+ block_sync_lds();
+ BlockwiseWelford::Run(
+ welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
+ });
+ // step2: normalization
+ for(index_t reducedTiles = 0; reducedTiles < numEBlockTileIteration_N; ++reducedTiles)
+ {
+ // TODO
+ }
+ } // run
};
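In the rewritten Run above, step 1 merges the per-N-block partial statistics and step 2 (still a TODO in this commit) is the normalization itself. A host-side sketch of what the two steps amount to, assuming the standard Welford parallel merge (Chan et al.) and the usual layernorm transform y = gamma * (x - mean) / sqrt(var + eps) + beta; all names here are illustrative, not CK's API:

#include <cmath>
#include <cstdio>

struct Partial
{
    float mean;
    float var; // population variance of this partial
    int count;
};

// Welford parallel merge of two partial (mean, var, count) triples.
Partial welford_merge(Partial a, Partial b)
{
    int count = a.count + b.count;
    if(count == 0)
        return {0.f, 0.f, 0};
    float delta = b.mean - a.mean;
    float mean  = a.mean + delta * b.count / count;
    // Rescale each side's variance back to M2 = var * count, combine, renormalize.
    float m2 = a.var * a.count + b.var * b.count +
               delta * delta * a.count * b.count / count;
    return {mean, m2 / count, count};
}

// Standard layernorm over one row, given the merged statistics.
void layernorm_row(const float* x, const float* gamma, const float* beta,
                   Partial stats, float epsilon, float* y, int n)
{
    float inv_std = 1.f / std::sqrt(stats.var + epsilon);
    for(int j = 0; j < n; ++j)
        y[j] = gamma[j] * (x[j] - stats.mean) * inv_std + beta[j];
}

int main()
{
    // Two partials covering {1, 2} and {3, 4}: merged mean 2.5, var 1.25.
    Partial merged = welford_merge({1.5f, 0.25f, 2}, {3.5f, 0.25f, 2});
    float x[4]     = {1.f, 2.f, 3.f, 4.f};
    float gamma[4] = {1.f, 1.f, 1.f, 1.f};
    float beta[4]  = {0.f, 0.f, 0.f, 0.f};
    float y[4];
    layernorm_row(x, gamma, beta, merged, 1e-5f, y, 4);
    printf("mean = %f, var = %f\n", merged.mean, merged.var);
    for(float v : y)
        printf("%f ", v); // roughly -1.34 -0.45 0.45 1.34
    printf("\n");
    return 0;
}

Because the merge is mathematically equivalent to a single Welford pass over the full row, the two-kernel split changes where the statistics are computed without changing what they are.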