Commit 787626fb authored by ltqin

add n repeat function

parent da047ec1
```diff
@@ -17,17 +17,60 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t RegSizePerXdlops,
           index_t MRepeat,
-          index_t NRepeat,
-          index_t MThreadSliceSize,
-          index_t NThreadSliceSize>
+          index_t NRepeat>
 struct BlockwiseSoftmax_V1
 {
+    static_assert(MRepeat == 1, "Now MRepeat must equal 1");
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
+    static constexpr index_t MThreadSliceSize = 1;
+    static constexpr index_t WaveSize         = 64;
 
     constexpr static auto c_thread_desc = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
 
+    using ThreadReduceSrcDesc_M_K = decltype(
+        make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<RegSizePerXdlops>{})));
+    using ThreadReduceDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})));
+
+    using ThreadwiseMaxReduce =
+        ThreadwiseReduction<AccDataType,
+                            ThreadReduceSrcDesc_M_K,
+                            ThreadReduceDstDesc_M,
+                            reduce::Max,
+                            false, // param ignored
+                            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+    using ThreadClusterLengths_M_K  = Sequence<MPerXDL, WaveSize / MPerXDL>;
+    using ThreadClusterArrangeOrder = Sequence<1, 0>;
+
+    using BlockwiseMaxReduce =
+        PartitionedBlockwiseReduction<AccDataType,
+                                      BlockSize,
+                                      ThreadClusterLengths_M_K,
+                                      ThreadClusterArrangeOrder,
+                                      reduce::Max,
+                                      false, // param ignored
+                                      detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+    using BlockwiseSumReduce =
+        PartitionedBlockwiseReduction<AccDataType,
+                                      BlockSize,
+                                      ThreadClusterLengths_M_K,
+                                      ThreadClusterArrangeOrder,
+                                      reduce::Add,
+                                      false, // ignored
+                                      detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
+
+    using ThreadwiseSumReduce =
+        ThreadwiseReduction<AccDataType,
+                            ThreadReduceSrcDesc_M_K,
+                            ThreadReduceDstDesc_M,
+                            reduce::Add,
+                            false, // ignored
+                            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
+
     template <typename CThreadBuffer>
     __host__ __device__ static void Run(CThreadBuffer& c_thread_buf, void* __restrict__ p_shared)
     {
```
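This hunk hoists the reduction type aliases to class scope and derives the thread-cluster shape from `MPerXDL` and the wavefront size instead of hard-coding `Sequence<32, 2>`. A minimal standalone sketch of that arithmetic (`MPerXDL = 32` is an assumed example value; `WaveSize = 64` is fixed by the new class):

```cpp
// Sketch only (not from the commit): checks the cluster-shape arithmetic
// behind ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>.
constexpr int MPerXDL  = 32; // assumed example value
constexpr int WaveSize = 64; // the new class fixes the wavefront at 64 lanes

// 32 lanes cover the M (row) direction; the remaining 64 / 32 = 2 lanes
// cooperate along the reduction (K) direction, reproducing the previously
// hard-coded Sequence<32, 2>.
static_assert(WaveSize / MPerXDL == 2, "equivalent to the old Sequence<32, 2>");

int main() { return 0; }
```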
```diff
@@ -44,74 +87,60 @@ struct BlockwiseSoftmax_V1
             accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
         });
 
-        constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, 0, 0));
-        auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
-
-        using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
-            make_tuple(Number<1>{}, Number<c_thread_desc.GetLength(I2)>{})));
-        using ThreadReduceDstDesc_M =
-            decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})));
-
-        using ThreadwiseMaxReduce =
-            ThreadwiseReduction<AccDataType,
-                                ThreadReduceSrcDesc_M_K,
-                                ThreadReduceDstDesc_M,
-                                reduce::Max,
-                                false, // param ignored
-                                detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
-
-        ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
-
-        using ThreadClusterLengths_M_K  = Sequence<32, 2>;
-        using ThreadClusterArrangeOrder = Sequence<1, 0>;
-        using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
-            AccDataType,
-            BlockSize,
-            ThreadClusterLengths_M_K,
-            ThreadClusterArrangeOrder,
-            reduce::Max,
-            false, // param ignored
-            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+        // max value for one thread
+        static_for<0, NRepeat, 1>{}([&](auto n) {
+            constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+            auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
+
+            ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
+        });
 
         auto reduce_work_buf =
             make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
+        block_sync_lds();
         BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
         block_sync_lds();
 
         // softmax
-        using BlockwiseSumReduce = PartitionedBlockwiseReduction<
-            AccDataType,
-            BlockSize,
-            ThreadClusterLengths_M_K,
-            ThreadClusterArrangeOrder,
-            reduce::Add,
-            false, // ignored
-            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
-
-        using ThreadwiseSumReduce =
-            ThreadwiseReduction<AccDataType,
-                                ThreadReduceSrcDesc_M_K,
-                                ThreadReduceDstDesc_M,
-                                reduce::Add,
-                                false, // ignored
-                                detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
-
-        static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
-            xdlops_out.template AsType<float>()(iK) =
-                math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
-        });
-        ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
-        block_sync_lds();
-        BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
-        block_sync_lds();
-        static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
-            xdlops_out.template AsType<float>()(iK) =
-                xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
-        });
+        // calculate exp for elements
+        static_for<0, NRepeat, 1>{}([&](auto n) {
+            constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+            auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
+
+            static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
+                xdlops_out.template AsType<float>()(iK) =
+                    math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
+            });
+        });
+
+        // sum data
+        static_for<0, NRepeat, 1>{}([&](auto n) {
+            constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+            auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
+
+            ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
+            block_sync_lds();
+        });
+        BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
+        block_sync_lds();
+
+        // change elements
+        static_for<0, NRepeat, 1>{}([&](auto n) {
+            constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+            auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
+
+            static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
+                xdlops_out.template AsType<float>()(iK) =
+                    xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
+            });
+        });
     }
 };
 } // namespace ck
```
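The core of the change is visible in `Run`: the single fixed tile at offset `(0, 0, 0)` becomes a `static_for` loop over all `NRepeat` tiles. Since `c_thread_desc` is a packed (row-major) descriptor over `(MRepeat, NRepeat, RegSizePerXdlops)` with `MRepeat == 1`, `CalculateOffset(make_tuple(0, n, 0))` is plain stride arithmetic. A host-side sketch with assumed example sizes (`NRepeat = 2`, `RegSizePerXdlops = 16`):

```cpp
// Host-side sketch of the register-offset math the NRepeat loop relies on;
// the sizes below are assumed example values, not taken from the commit.
#include <cstdio>

int main()
{
    constexpr int NRepeat          = 2;
    constexpr int RegSizePerXdlops = 16;

    // For a packed (MRepeat, NRepeat, RegSizePerXdlops) descriptor with
    // MRepeat == 1, offset(0, n, 0) reduces to n * RegSizePerXdlops.
    for(int n = 0; n < NRepeat; ++n)
        std::printf("N-tile %d starts at register offset %d\n",
                    n, n * RegSizePerXdlops); // prints 0, then 16
    return 0;
}
```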
```diff
@@ -480,9 +480,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                 NPerXDL,
                 blockwise_gemm.GetRegSizePerXdlops(),
                 MXdlPerWave,
-                NXdlPerWave,
-                1,
-                1>;
+                NXdlPerWave>;
 
         BlockwiseSoftmax::Run(c_thread_buf, p_reduce_work_buffer);
     }
```
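At the call site, the two hard-coded `1` slice-size arguments are dropped and `NXdlPerWave` now supplies `NRepeat` directly, so the softmax visits every N-direction xdlops tile a thread owns. For reference, the three passes `Run` performs (block-wide max reduce, `exp(x - max)` with a sum reduce, then division by the sum) compute the standard numerically stable softmax; a host-side C++ sketch of the same math on one row (illustrative only, not the kernel's API):

```cpp
// Host-side reference for the numerically stable softmax the kernel computes:
// (1) row max, (2) exp(x - max) plus row sum, (3) divide by the sum.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void softmax_row(std::vector<float>& row)
{
    // Pass 1: max reduce (ThreadwiseMaxReduce + BlockwiseMaxReduce in the kernel).
    const float m = *std::max_element(row.begin(), row.end());

    // Pass 2: shift by the max so exp() cannot overflow, and accumulate the sum
    // (the "calculate exp" loop plus ThreadwiseSumReduce/BlockwiseSumReduce).
    float sum = 0.0f;
    for(float& x : row)
    {
        x = std::exp(x - m);
        sum += x;
    }

    // Pass 3: normalize (the "change elements" loop).
    for(float& x : row)
        x /= sum;
}

int main()
{
    std::vector<float> row{1.0f, 2.0f, 3.0f, 4.0f};
    softmax_row(row);
    for(float x : row)
        std::printf("%f ", x); // printed values sum to 1
    std::printf("\n");
    return 0;
}
```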