Commit f8c44314 authored by Anthony Chang
Browse files

Reflect the reduction API's recent change (GetReductionZeroVal renamed to GetIdentityValue)

parent 7e610626
...@@ -819,12 +819,11 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -819,12 +819,11 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
3, 3,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
1, 1,
true>( true>(c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock, make_multi_index(block_work_idx[I0],
make_multi_index(block_work_idx[I0], c_reduce_thread_data_idx_begin[I0],
c_reduce_thread_data_idx_begin[I0], block_work_idx[I1],
block_work_idx[I1], c_reduce_thread_data_idx_begin[I1]));
c_reduce_thread_data_idx_begin[I1]));
// Note: c0_add is of same layout as c so we don't declare new c0_add_desc here // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here
auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
...@@ -837,12 +836,11 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -837,12 +836,11 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
3, 3,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
1, 1,
true>( true>(c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock, make_multi_index(block_work_idx[I0],
make_multi_index(block_work_idx[I0], c_reduce_thread_data_idx_begin[I0],
c_reduce_thread_data_idx_begin[I0], block_work_idx[I1],
block_work_idx[I1], c_reduce_thread_data_idx_begin[I1]));
c_reduce_thread_data_idx_begin[I1]));
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
...@@ -885,10 +883,10 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -885,10 +883,10 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// load from LDS and global, add bias // load from LDS and global, add bias
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf, c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock, c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0), make_tuple(I0, I0),
c_reduce_thread_buf); c_reduce_thread_buf);
c0_thread_copy_global_to_vgpr.Run( c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock, c0_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -900,8 +898,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -900,8 +898,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) { [&](auto i) {
FloatReduceAcc out; FloatReduceAcc out;
acc_element_op(out, c_reduce_thread_buf(i) + acc_element_op(out,
static_cast<FloatReduceAcc>(c0_thread_buf(i))); c_reduce_thread_buf(i) +
static_cast<FloatReduceAcc>(c0_thread_buf(i)));
c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias) c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias)
}); });
...@@ -933,8 +932,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -933,8 +932,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
reduce::SquaredAdd<FloatReduceAcc>, reduce::SquaredAdd<FloatReduceAcc>,
false>; false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetReductionZeroVal(); const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetIdentityValue();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetReductionZeroVal(); const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetIdentityValue();
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; }); [&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
......
...@@ -76,7 +76,7 @@ struct SquaredAdd ...@@ -76,7 +76,7 @@ struct SquaredAdd
{ {
using dataType = T; using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }; __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b * b; } __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b * b; }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment