"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "307c524bac267c99a93dea5a694acc11f4b1536f"
Commit ac6977f7 authored by Anthony Chang's avatar Anthony Chang
Browse files

tidy up

parent 2d91fd12
...@@ -336,32 +336,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle ...@@ -336,32 +336,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
} }
} }
// Builds the 1-D grid descriptor for the M dimension.
//
// Assumes a packed tensor: the raw descriptor is created with
// make_naive_tensor_descriptor_packed over the single length MRaw.
//
// MRaw: unpadded length of the M dimension.
// Returns: the packed descriptor unchanged, or, when GemmSpec is one of the
// specializations that pad M (MPadding, MNPadding, MKPadding, MNKPadding),
// that descriptor with a right-pad transform of MPad elements applied to
// dimension 0.
// NOTE(review): the two branches presumably yield different descriptor types,
// which would be why the choice is made with `if constexpr` under a deduced
// `auto` return type -- confirm against transform_tensor_descriptor.
static auto MakeGridDescriptor_M(index_t MRaw)
{
const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
// M is a multiple of MPerBlock; going by the helper's name this is MRaw
// rounded up to the next multiple (verify math::integer_divide_ceil).
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
// Extra elements needed to grow MRaw to M.
const auto MPad = M - MRaw;
// Compile-time selection: only the specializations that name M get padding.
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return grid_desc_mraw;
}
}
static auto MakeGridDescriptor_N(index_t NRaw) static auto MakeGridDescriptor_N(index_t NRaw)
{ {
const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw)); const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));
...@@ -604,7 +578,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle ...@@ -604,7 +578,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
ave_time = ave_time =
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
kernel, kernel,
......
...@@ -182,9 +182,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -182,9 +182,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{}, // 1 * MWave * 32 Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1, I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{})); // 1 * NWave * 32 Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
} }
...@@ -296,24 +296,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -296,24 +296,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return c_grid_desc_mblock_mperblock_nblock_nperblock; return c_grid_desc_mblock_mperblock_nblock_nperblock;
} }
// for broadcasting bias, beta, gamma
// __host__ __device__ static constexpr auto
// MakeC0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock)
// {
// const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0);
// const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
// c0_grid_desc_nblock_nperblock,
// make_tuple(make_insert_transform(I1),
// make_insert_transform(I1),
// make_pass_through_transform(NBlock),
// make_pass_through_transform(NPerBlock)),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// return c0_grid_desc_mblock_mperblock_nblock_nperblock;
// }
// for bias, beta, gamma // for bias, beta, gamma
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n) MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n)
...@@ -411,16 +393,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -411,16 +393,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c0_beta_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c0_beta_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
// if (hipThreadIdx_x == 0 && hipBlockIdx_x == 0) c_grid_desc_mblock_mperblock_nblock_nperblock.Print();
/*
{TensorDescriptor,
transforms: {Embed, up_lengths_ {MultiIndex, size 2,256 128 }coefficients_ {MultiIndex, size 2,128 1 }}LowerDimensionIds:{size 1, 0 }UpperDimensionIds:{size 2, 1 2 }
transforms: {UnMerge, up_lengths_{MultiIndex, size 2,1 256 }up_lengths_scan_{MultiIndex, size 2,256 1 }}LowerDimensionIds:{size 1, 1 }UpperDimensionIds:{size 2, 3 4 }
transforms: {UnMerge, up_lengths_{MultiIndex, size 2,1 128 }up_lengths_scan_{MultiIndex, size 2,128 1 }}LowerDimensionIds:{size 1, 2 }UpperDimensionIds:{size 2, 5 6 }
}
{size 4, 3 4 5 6 }
*/
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...@@ -891,7 +863,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -891,7 +863,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_for<0, num_access, 1>{}([&](auto access_id) { static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS // make sure it's safe to write to LDS
// __syncthreads();
block_sync_lds(); block_sync_lds();
// each thread write its data from VGPR to LDS // each thread write its data from VGPR to LDS
...@@ -901,9 +872,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -901,9 +872,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf); c_shuffle_block_buf);
// make sure it's safe to read from LDS
// debug::print_shared(c_shuffle_block_buf.p_data_, c_shuffle_block_buf.element_space_size_);
// __syncthreads();
block_sync_lds(); block_sync_lds();
// layernorm // layernorm
...@@ -924,34 +892,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -924,34 +892,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c0_thread_buf); c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) { static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
// auto thread_slice_desc = make_cluster_descriptor(
// Sequence<mreduce_per_thread, nreduce_per_thread>{});
// auto thread_slice_idx = thread_slice_desc.CalculateBottomIndex(make_multi_index(i));
// printf("tid %zd, access_id %d, im, in %d, %d, c0 = %f, c = %f\n",
// hipThreadIdx_x,
// access_id.value,
// thread_slice_idx[I0],
// thread_slice_idx[I1],
// c0_thread_buf(i),
// c_reduce_thread_buf(i));
c_reduce_thread_buf(i) += c0_thread_buf(i); c_reduce_thread_buf(i) += c0_thread_buf(i);
}); });
// static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
// static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
// constexpr auto offset =
// Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
// make_tuple(im, in))>{};
// c_reduce_thread_buf(offset) += c0_thread_buf(offset);
// // printf("tid %zd, access_id %d, im, in %d, %d, c0 = %f, c+c0 = %f\n",
// // hipThreadIdx_x,
// // access_id.value,
// // im.value,
// // in.value,
// // c0_thread_buf(offset),
// // c_reduce_thread_buf(offset));
// });
// });
using ThreadwiseReduceD0 = using ThreadwiseReduceD0 =
ThreadwiseReduction<FloatReduceAcc, ThreadwiseReduction<FloatReduceAcc,
...@@ -979,7 +921,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -979,7 +921,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// reduce squared sum in VGPR // reduce squared sum in VGPR
ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf); ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// reduce across workgorup // reduce within workgroup
using BlockwiseReduce = PartitionedBlockwiseReduction<FloatReduceAcc, using BlockwiseReduce = PartitionedBlockwiseReduction<FloatReduceAcc,
BlockSize, BlockSize,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
...@@ -992,17 +934,11 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -992,17 +934,11 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseReduce::Reduce(d_reduce_work_buf, d0_thread_buf(i)); // blockwise reduced sum BlockwiseReduce::Reduce(d_reduce_work_buf, d0_thread_buf(i)); // blockwise reduced sum
block_sync_lds(); block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf, d1_thread_buf(i)); // blockwise reduced squared sum BlockwiseReduce::Reduce(d_reduce_work_buf, d1_thread_buf(i)); // blockwise reduced squared sum
// printf("tid %zd, access_id %d, mreduce_idx %d, sum = %f, sq sum = %f\n",
// hipThreadIdx_x,
// access_id.value,
// i.value,
// d0_thread_buf(i),
// d1_thread_buf(i));
}); });
// normalize // normalize
const index_t NRaw = c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0].GetUpperLengths()[I1]; // TODO: proper handle const index_t NRaw = c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0].GetUpperLengths()[I1]; // TODO: proper handle
// if(hipThreadIdx_x == 0) printf("NRaw = %d\n", NRaw);
static_for<0, mreduce_per_thread, 1>{}([&](auto im) { static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) { static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto dst_offset = constexpr auto dst_offset =
...@@ -1022,14 +958,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -1022,14 +958,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
tensor_operation::element_wise::UnarySqrt<FloatReduceAcc, FloatReduceAcc>{}(divisor_sqrt, divisor); tensor_operation::element_wise::UnarySqrt<FloatReduceAcc, FloatReduceAcc>{}(divisor_sqrt, divisor);
c_reduce_thread_buf(dst_offset) = denom / divisor_sqrt; c_reduce_thread_buf(dst_offset) = denom / divisor_sqrt;
// printf("tid %zd, access_id %d, reduce_idx %d %d, avg_sum = %f, avg sq sum = %f, final = %f\n",
// hipThreadIdx_x,
// access_id.value,
// im.value,
// in.value,
// avg_sum,
// avg_squared_sum,
// c_reduce_thread_buf(dst_offset));
}); });
}); });
...@@ -1056,7 +984,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -1056,7 +984,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_reduce_thread_buf(i) += c0_thread_buf(i); // + beta c_reduce_thread_buf(i) += c0_thread_buf(i); // + beta
}); });
// __syncthreads();
block_sync_lds(); block_sync_lds();
c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock, c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock,
...@@ -1067,9 +994,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -1067,9 +994,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} // end layernorm } // end layernorm
// __syncthreads();
block_sync_lds(); block_sync_lds();
// debug::print_shared<32>(c_shuffle_block_buf.p_data_, c_shuffle_block_buf.element_space_size_);
// each block copy its data from LDS to global // each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run( c_shuffle_block_copy_lds_to_global.Run(
...@@ -1086,7 +1011,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -1086,7 +1011,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on C0 bias // move on C0
c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment