Commit a085b740 authored by ltqin's avatar ltqin
Browse files

complete first verison

parent 652728bc
...@@ -23,10 +23,8 @@ struct BlockwiseSoftmax_V1 ...@@ -23,10 +23,8 @@ struct BlockwiseSoftmax_V1
{ {
static_assert(MRepeat == 1, "Now MRepeat must equal 1"); static_assert(MRepeat == 1, "Now MRepeat must equal 1");
static __shared__ AccDataType p_lex[MPerBlock];
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t MThreadSliceSize = 1; static constexpr index_t MThreadSliceSize = 1;
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = 64;
...@@ -36,8 +34,9 @@ struct BlockwiseSoftmax_V1 ...@@ -36,8 +34,9 @@ struct BlockwiseSoftmax_V1
{ {
__host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default; __host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;
template <typename TopIdx> template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const __host__ __device__ static constexpr auto CalculateBottomIndex(const TopIdx& idx_top)
{ {
const auto index = idx_top[I0]; const auto index = idx_top[I0];
const auto m = (index / WaveSize) * MPerXDL + index % MPerXDL; const auto m = (index / WaveSize) * MPerXDL + index % MPerXDL;
const auto k = (index % WaveSize) / MPerXDL; const auto k = (index % WaveSize) / MPerXDL;
...@@ -45,6 +44,9 @@ struct BlockwiseSoftmax_V1 ...@@ -45,6 +44,9 @@ struct BlockwiseSoftmax_V1
} }
}; };
static constexpr auto softmax_buf_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerBlock>{}, Number<2>{}));
constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed( constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
...@@ -89,13 +91,22 @@ struct BlockwiseSoftmax_V1 ...@@ -89,13 +91,22 @@ struct BlockwiseSoftmax_V1
false, // ignored false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
template <typename CThreadBuffer> template <typename CThreadBuffer>
__host__ __device__ static void Run(CThreadBuffer& in_thread_buf, void* __restrict__ p_shared) __host__ __device__ static void
Run(CThreadBuffer& in_thread_buf, void* __restrict__ p_shared, void* __restrict__ p_softmax)
{ {
// printf("in_thread_desc: {%d, %d, %d}", in_thread_desc.GetLength(I0).value, // printf("in_thread_desc: {%d, %d, %d}", in_thread_desc.GetLength(I0).value,
// in_thread_desc.GetLength(I1).value, in_thread_desc.GetLength(I2).value); // in_thread_desc.GetLength(I1).value, in_thread_desc.GetLength(I2).value);
auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_shared), BlockSize); static_cast<AccDataType*>(p_shared), BlockSize);
auto softmax_lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_softmax), MPerBlock * 2);
// static auto lds_buffer_m_k = GetSpaceForPreMax();
const index_t thread_local_id = get_thread_local_1d_id();
const auto thread_cluster_idx =
BlockToMKMap_M0_K_M1Adapt::CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
// const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
// //
// find max value // find max value
// //
...@@ -118,9 +129,11 @@ struct BlockwiseSoftmax_V1 ...@@ -118,9 +129,11 @@ struct BlockwiseSoftmax_V1
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0)); BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
block_sync_lds(); block_sync_lds();
// save max value
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
make_tuple(thread_m_cluster_id, 1))) = max_value_buf(I0);
// {const index_t thread_local_id = get_thread_local_1d_id(); printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);
// printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);}
// //
// softmax // softmax
...@@ -150,16 +163,20 @@ struct BlockwiseSoftmax_V1 ...@@ -150,16 +163,20 @@ struct BlockwiseSoftmax_V1
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0)); BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
block_sync_lds(); block_sync_lds();
// save sum
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
make_tuple(thread_m_cluster_id, 0))) = accu_value_buf(I0);
// change elements // change elements
static_for<0, NRepeat, 1>{}([&](auto n) { /* static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0)); constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = in_thread_buf.GetVectorTypeReference(Number<in_offset>{}); auto& xdlops_out =
in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
static_for<0, in_thread_desc.GetLength(I2), 1>{}([&](auto iK) { static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) = xdlops_out.template AsType<float>()(iK) =
xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0); xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
}); });
}); });*/
} }
}; // namespace ck }; // namespace ck
......
...@@ -474,6 +474,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -474,6 +474,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
num_k_block_main_loop); num_k_block_main_loop);
{ {
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
__shared__ AccDataType p_lex[MPerBlock * 2];
using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize, using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize,
FloatAcc, FloatAcc,
MPerBlock, MPerBlock,
...@@ -482,7 +484,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -482,7 +484,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
blockwise_gemm.GetRegSizePerXdlops(), blockwise_gemm.GetRegSizePerXdlops(),
MXdlPerWave, MXdlPerWave,
NXdlPerWave>; NXdlPerWave>;
BlockwiseSoftmax::Run(c_thread_buf, p_reduce_work_buffer); BlockwiseSoftmax::Run(c_thread_buf, p_reduce_work_buffer, p_lex);
} }
// output: register to global memory // output: register to global memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment