Commit f8c44314 authored by Anthony Chang's avatar Anthony Chang
Browse files

reflect reduction API's recent change

parent 7e610626
...@@ -819,8 +819,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -819,8 +819,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
3, 3,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
1, 1,
true>( true>(c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], make_multi_index(block_work_idx[I0],
c_reduce_thread_data_idx_begin[I0], c_reduce_thread_data_idx_begin[I0],
block_work_idx[I1], block_work_idx[I1],
...@@ -837,8 +836,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -837,8 +836,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
3, 3,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
1, 1,
true>( true>(c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], make_multi_index(block_work_idx[I0],
c_reduce_thread_data_idx_begin[I0], c_reduce_thread_data_idx_begin[I0],
block_work_idx[I1], block_work_idx[I1],
...@@ -900,7 +898,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -900,7 +898,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) { [&](auto i) {
FloatReduceAcc out; FloatReduceAcc out;
acc_element_op(out, c_reduce_thread_buf(i) + acc_element_op(out,
c_reduce_thread_buf(i) +
static_cast<FloatReduceAcc>(c0_thread_buf(i))); static_cast<FloatReduceAcc>(c0_thread_buf(i)));
c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias) c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias)
}); });
...@@ -933,8 +932,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -933,8 +932,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
reduce::SquaredAdd<FloatReduceAcc>, reduce::SquaredAdd<FloatReduceAcc>,
false>; false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetReductionZeroVal(); const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetIdentityValue();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetReductionZeroVal(); const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetIdentityValue();
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; }); [&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
......
...@@ -76,7 +76,7 @@ struct SquaredAdd ...@@ -76,7 +76,7 @@ struct SquaredAdd
{ {
using dataType = T; using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }; __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b * b; } __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b * b; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment