Commit 787626fb authored by ltqin

add NRepeat support

parent da047ec1
@@ -17,17 +17,60 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t RegSizePerXdlops,
           index_t MRepeat,
-          index_t NRepeat,
-          index_t MThreadSliceSize,
-          index_t NThreadSliceSize>
+          index_t NRepeat>
 struct BlockwiseSoftmax_V1
 {
-    static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
-    static constexpr auto I2 = Number<2>{};
+    static_assert(MRepeat == 1, "Now MRepeat must equal 1");
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr index_t MThreadSliceSize = 1;
+    static constexpr index_t WaveSize         = 64;
     constexpr static auto c_thread_desc = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
+    using ThreadReduceSrcDesc_M_K = decltype(
+        make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<RegSizePerXdlops>{})));
+    using ThreadReduceDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})));
+    using ThreadwiseMaxReduce =
+        ThreadwiseReduction<AccDataType,
+                            ThreadReduceSrcDesc_M_K,
+                            ThreadReduceDstDesc_M,
+                            reduce::Max,
+                            false, // param ignored
+                            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+    using ThreadClusterLengths_M_K  = Sequence<MPerXDL, WaveSize / MPerXDL>;
+    using ThreadClusterArrangeOrder = Sequence<1, 0>;
+    using BlockwiseMaxReduce =
+        PartitionedBlockwiseReduction<AccDataType,
+                                      BlockSize,
+                                      ThreadClusterLengths_M_K,
+                                      ThreadClusterArrangeOrder,
+                                      reduce::Max,
+                                      false, // param ignored
+                                      detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+    using BlockwiseSumReduce =
+        PartitionedBlockwiseReduction<AccDataType,
+                                      BlockSize,
+                                      ThreadClusterLengths_M_K,
+                                      ThreadClusterArrangeOrder,
+                                      reduce::Add,
+                                      false, // ignored
+                                      detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
+    using ThreadwiseSumReduce =
+        ThreadwiseReduction<AccDataType,
+                            ThreadReduceSrcDesc_M_K,
+                            ThreadReduceDstDesc_M,
+                            reduce::Add,
+                            false, // ignored
+                            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
     template <typename CThreadBuffer>
     __host__ __device__ static void Run(CThreadBuffer& c_thread_buf, void* __restrict__ p_shared)
     {
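For reference, the reduction cluster built by the new class-scope aliases is `Sequence<MPerXDL, WaveSize / MPerXDL>`, generalizing the previously hard-coded `Sequence<32, 2>`. A minimal standalone sketch of that arithmetic, assuming `MPerXDL = 32` (not shown in this diff; inferred from the old hard-coded value) and the fixed `WaveSize = 64`:

```cpp
#include <cstdio>

int main()
{
    constexpr int MPerXDL  = 32; // assumed for illustration, matching the old Sequence<32, 2>
    constexpr int WaveSize = 64; // fixed constant in BlockwiseSoftmax_V1

    // ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>
    constexpr int cluster_m = MPerXDL;            // rows reduced in parallel within a wave
    constexpr int cluster_k = WaveSize / MPerXDL; // lanes cooperating on each row

    std::printf("reduction cluster M x K = %d x %d\n", cluster_m, cluster_k); // 32 x 2
    return 0;
}
```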
@@ -44,74 +87,60 @@ struct BlockwiseSoftmax_V1
             accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
         });
-        constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, 0, 0));
-        auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
-        using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
-            make_tuple(Number<1>{}, Number<c_thread_desc.GetLength(I2)>{})));
-        using ThreadReduceDstDesc_M =
-            decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})));
-        using ThreadwiseMaxReduce =
-            ThreadwiseReduction<AccDataType,
-                                ThreadReduceSrcDesc_M_K,
-                                ThreadReduceDstDesc_M,
-                                reduce::Max,
-                                false, // param ignored
-                                detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
-        ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
-        // const index_t thread_local_id = get_thread_local_1d_id();
-        // printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
+        // max value for one thread
+        static_for<0, NRepeat, 1>{}([&](auto n) {
+            constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+            auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
-        using ThreadClusterLengths_M_K  = Sequence<32, 2>;
-        using ThreadClusterArrangeOrder = Sequence<1, 0>;
-        using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
-            AccDataType,
-            BlockSize,
-            ThreadClusterLengths_M_K,
-            ThreadClusterArrangeOrder,
-            reduce::Max,
-            false, // param ignored
-            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+            ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
+        });
+        //{const index_t thread_local_id = get_thread_local_1d_id();
+        // printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
+        // ignore = p_reduce_work_buffer;}
         auto reduce_work_buf =
             make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
         block_sync_lds();
         BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
         block_sync_lds();
-        // printf("\n");
-        // printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
+        // {const index_t thread_local_id = get_thread_local_1d_id();
+        // printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);}
         // softmax
-        using BlockwiseSumReduce = PartitionedBlockwiseReduction<
-            AccDataType,
-            BlockSize,
-            ThreadClusterLengths_M_K,
-            ThreadClusterArrangeOrder,
-            reduce::Add,
-            false, // ignored
-            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
-        using ThreadwiseSumReduce =
-            ThreadwiseReduction<AccDataType,
-                                ThreadReduceSrcDesc_M_K,
-                                ThreadReduceDstDesc_M,
-                                reduce::Add,
-                                false, // ignored
-                                detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
-        static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
-            xdlops_out.template AsType<float>()(iK) =
-                math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
-        });
-        ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
-        block_sync_lds();
-        BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
-        block_sync_lds();
-        static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
-            xdlops_out.template AsType<float>()(iK) =
-                xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
-        });
+        {
+            // calculate exp for elements
+            static_for<0, NRepeat, 1>{}([&](auto n) {
+                constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+                auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
+                static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
+                    xdlops_out.template AsType<float>()(iK) =
+                        math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
+                });
+            });
+            // sum data
+            static_for<0, NRepeat, 1>{}([&](auto n) {
+                constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+                auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
+                ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
+                block_sync_lds();
+            });
+            BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
+            block_sync_lds();
+            // change elements
+            static_for<0, NRepeat, 1>{}([&](auto n) {
+                constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+                auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
+                static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
+                    xdlops_out.template AsType<float>()(iK) =
+                        xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
+                });
+            });
+        }
     }
 };
-}; // namespace ck
+} // namespace ck
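For context, the restructured `Run` applies a numerically stable softmax over the `NRepeat * RegSizePerXdlops` accumulator elements held by each thread: max-reduce, subtract-and-exponentiate, sum-reduce, then divide. The sketch below is an illustrative scalar reference of that sequence only; it omits the thread/block reduction machinery (`ThreadwiseReduction`, `PartitionedBlockwiseReduction`, LDS synchronization) that the kernel uses to combine partial max and sum values across the thread cluster.

```cpp
#include <cmath>
#include <vector>

// Scalar reference for the softmax sequence in BlockwiseSoftmax_V1::Run.
// One "row" here stands for the NRepeat * RegSizePerXdlops registers a thread
// holds for a single M slice (MRepeat == 1 is asserted in the new code).
void softmax_row_reference(std::vector<float>& row)
{
    // 1) max reduction (ThreadwiseMaxReduce + BlockwiseMaxReduce in the kernel)
    float max_value = row[0];
    for(float v : row)
        max_value = std::fmax(max_value, v);

    // 2) exp(x - max) and 3) sum reduction (ThreadwiseSumReduce + BlockwiseSumReduce)
    float sum = 0.0f;
    for(float& v : row)
    {
        v = std::exp(v - max_value);
        sum += v;
    }

    // 4) normalize by the row sum
    for(float& v : row)
        v /= sum;
}
```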
@@ -480,9 +480,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             NPerXDL,
             blockwise_gemm.GetRegSizePerXdlops(),
             MXdlPerWave,
-            NXdlPerWave,
-            1,
-            1>;
+            NXdlPerWave>;
         BlockwiseSoftmax::Run(c_thread_buf, p_reduce_work_buffer);
     }
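One usage note on the unchanged call site: `Run` views `p_reduce_work_buffer` as an LDS buffer of `BlockSize` elements via `make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize)`, so the caller still has to reserve at least that much shared memory for the reduction workspace. A hedged sketch of the sizing, assuming the accumulator type is `float` (an assumption, not confirmed by this diff):

```cpp
#include <cstddef>

// Illustrative only; AccDataType = float is an assumption for this sketch.
using AccDataType = float;

constexpr std::size_t ReduceWorkspaceBytes(std::size_t block_size)
{
    // reduce_work_buf views BlockSize elements of AccDataType in LDS
    return block_size * sizeof(AccDataType);
}

static_assert(ReduceWorkspaceBytes(256) == 1024, "256 threads -> 1 KiB of float workspace");
```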