Commit 480d6219 authored by ltqin's avatar ltqin
Browse files

finish one xdlops softmax

parent 52c5668f
......@@ -77,7 +77,7 @@ std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
os << "[";
for(size_t y = 0; y < matrix.mDesc.GetLengths()[1]; y++)
{
os << std::setw(5) << static_cast<float>(matrix(x, y));
os << std::setw(6) << std::setprecision(4) << static_cast<float>(matrix(x, y));
}
os << "]" << std::endl;
}
......
......@@ -263,6 +263,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
......
......@@ -13,6 +13,12 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
namespace ck {
......@@ -124,6 +130,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
using AccDataType = FloatAcc;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
......@@ -470,6 +477,92 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
{
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, 1, true> max_value_buf;
static_for<0, 1, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
});
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, 1, true> accu_value_buf;
static_for<0, 1, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
constexpr auto c_thread_desc = blockwise_gemm.GetCThreadDesc();
// printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value,
// c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2));
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, 0, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<c_thread_desc.GetLength(I2)>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})));
using ThreadwiseMaxReduce =
ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Max,
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
// const index_t thread_local_id = get_thread_local_1d_id();
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
using ThreadClusterLengths_M_K = Sequence<32, 2>;
using ThreadClusterArrangeOrder = Sequence<1, 0>;
using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Max,
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
block_sync_lds();
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
block_sync_lds();
// printf("\n");
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// softmax
using BlockwiseSumReduce = PartitionedBlockwiseReduction<
AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using ThreadwiseSumReduce =
ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) =
math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
});
ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
block_sync_lds();
static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) =
xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
});
}
// output: register to global memory
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment