Commit e1e01d8f authored by ltqin

add blockwise softmax v1

parent 480d6219
@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
- < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+ < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on
using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
bool time_kernel = false;

// GEMM shape
- ck::index_t M = 32;
- ck::index_t N = 32;
+ ck::index_t M = 16;
+ ck::index_t N = 16;
ck::index_t K = 64;
ck::index_t StrideA = K;
...
@@ -263,7 +263,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}

- __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
+ __host__ __device__ static constexpr index_t GetRegSizePerXdlops()
+ {
+     return xdlops_gemm.GetRegSizePerXdlops();
+ }

static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
namespace ck {
template <index_t BlockSize,
typename AccDataType,
index_t MPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t RegSizePerXdlops,
index_t MRepeat,
index_t NRepeat>
struct BlockwiseSoftmax_V1
{
static_assert(MRepeat == 1, "Currently MRepeat must equal 1");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t MThreadSliceSize = 1;
static constexpr index_t WaveSize = 64;
static_assert(MPerBlock == MPerXDL * BlockSize / WaveSize, "waves must be laid out along the M direction only");
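// Illustrative sizing (values assumed from the accompanying example, not enforced here):
// with BlockSize = 64, WaveSize = 64 and MPerXDL = 16, the assert above requires
// MPerBlock = 16, and ThreadClusterLengths_M_K below becomes Sequence<16, 4>,
// i.e. 16 rows (M) each reduced by 4 threads (K); 16 * 4 == BlockSize.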
struct BlockToMKMap_M0_K_M1Adapt
{
__host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;
template <typename TopIdx>
__host__ __device__ static constexpr auto CalculateBottomIndex(const TopIdx& idx_top)
{
const auto index = idx_top[I0];
const auto m = (index / WaveSize) * MPerXDL + index % MPerXDL;
const auto k = (index % WaveSize) / MPerXDL;
return make_tuple(m, k);
}
};
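// Worked example of the mapping above (illustrative only, assuming WaveSize = 64 and MPerXDL = 16):
//   thread 37 -> m = (37 / 64) * 16 + 37 % 16 = 5,  k = (37 % 64) / 16 = 2
//   thread 70 -> m = (70 / 64) * 16 + 70 % 16 = 22, k = (70 % 64) / 16 = 0
// i.e. the lanes of one wave cover MPerXDL rows, with WaveSize / MPerXDL lanes per row (the K axis).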
constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
using ThreadReduceSrcDesc_M_K = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<RegSizePerXdlops>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})));
using ThreadwiseMaxReduce =
ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Max,
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using ThreadClusterLengths_M_K = Sequence<MPerBlock, WaveSize / MPerXDL>;
using BlockwiseMaxReduce =
PartitionedBlockwiseReduction2<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
BlockToMKMap_M0_K_M1Adapt,
reduce::Max,
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using BlockwiseSumReduce =
PartitionedBlockwiseReduction2<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
BlockToMKMap_M0_K_M1Adapt,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using ThreadwiseSumReduce =
ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
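// Run() below computes, per row of the C tile held in registers:
//   1) a threadwise max over each thread's RegSizePerXdlops accumulators, then a blockwise
//      max-reduction through LDS (p_reduce)                      -> f_max
//   2) an in-place update of the accumulators: P = exp(s - f_max)
//   3) a threadwise sum over the updated accumulators, then a blockwise sum-reduction -> f_sum
// The buffer is left holding exp(s - max); dividing by f_sum (the usual softmax
// normalization) is not done here and is left to the caller.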
template <typename CThreadBuffer>
__host__ __device__ static void
Run(CThreadBuffer& in_thread_buf, float& f_sum, float& f_max, void* __restrict__ p_reduce)
{
auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_reduce), BlockSize);
//
// find max value
//
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
});
// max value for one thread
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
});
// block reduce for max
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
block_sync_lds();
// save max
f_max = max_value_buf(I0);
//
// softmax
//
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
// calculate exp for elements, P=exp(s-max)
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) =
math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
});
});
// sum data
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
block_sync_lds();
});
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
block_sync_lds();
// save sum
f_sum = accu_value_buf(I0);
}
}; // struct BlockwiseSoftmax_V1
} // namespace ck
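For reference, a minimal single-threaded sketch of the per-row math that BlockwiseSoftmax_V1::Run implements with wave/block-level reductions (illustrative plain C++ on a std::vector<float> row; an editorial sketch, not part of this commit):

#include <algorithm>
#include <cmath>
#include <vector>

// Reference softmax over one row: out[i] = exp(in[i] - max) / sum_j exp(in[j] - max).
// Subtracting the row max mirrors the kernel and keeps exp() from overflowing.
inline void softmax_row_reference(std::vector<float>& row)
{
    if(row.empty())
        return;
    const float row_max = *std::max_element(row.begin(), row.end()); // "find max value"
    float row_sum       = 0.f;
    for(float& v : row)
    {
        v = std::exp(v - row_max); // P = exp(s - max)
        row_sum += v;              // "sum data"
    }
    for(float& v : row)
        v /= row_sum; // normalization by f_sum, done by the caller of Run()
}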
@@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction
};
};
// clang-format off
// Assume:
// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize
// 3) in_out_value is the input data in vgpr from each thread
// 4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
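// Descriptive note on Reduce() below: each thread first writes in_out_value to its (m, k)
// slot in work_buffer, then the K threads of a row combine their slots in cluster_len_shift
// halving steps (a tree reduction). For example, assuming BufferLength_K = 4 (so that
// cluster_len_shift = 2): step 0 accumulates slot k+2 into slot k for k < 2, step 1
// accumulates slot k+1 into slot k for k < 1, leaving the row result in slot 0, which every
// thread of the row finally reads back into in_out_value.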
template <typename AccDataType,
index_t BlockSize,
typename ThreadClusterLengths_M_K,
typename ThreadClusterDesc,
typename OpReduce,
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction2
{
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
"The product of the cluster lengths must equal BlockSize!");
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
static_assert(BufferLength_K > 1, "Parallel reduction needs to work on at least two elements");
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
static constexpr auto thread_cluster_desc = ThreadClusterDesc{};
template <typename BufferType>
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
{
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
__syncthreads();
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
if(thread_k_cluster_id < indOffset)
{
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
AccDataType opData1 = work_buffer[offset1];
AccDataType opData2 = work_buffer[offset2];
Accumulation::Calculate(opData1, opData2);
work_buffer(offset1) = opData1;
}
__syncthreads();
});
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
in_out_value = work_buffer[offset];
};
};
// clang-format off
// Assume:
//  1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
...
@@ -13,12 +13,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
- #include "ck/utility/data_type.hpp"
- #include "ck/utility/reduction_common.hpp"
- #include "ck/utility/reduction_operator.hpp"
- #include "ck/utility/reduction_functions_accumulate.hpp"
- #include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
- #include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
+ #include "ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp"

namespace ck {
@@ -478,90 +473,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf,
num_k_block_main_loop);
{
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
- StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, 1, true> max_value_buf;
- static_for<0, 1, 1>{}([&](auto I) {
-     max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
- });
- StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, 1, true> accu_value_buf;
- static_for<0, 1, 1>{}([&](auto I) {
-     accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
- });
- constexpr auto c_thread_desc = blockwise_gemm.GetCThreadDesc();
- // printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value,
- // c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2));
- constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, 0, 0));
- auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
- using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
-     make_tuple(Number<1>{}, Number<c_thread_desc.GetLength(I2)>{})));
- using ThreadReduceDstDesc_M =
-     decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})));
- using ThreadwiseMaxReduce =
-     ThreadwiseReduction<AccDataType,
-                         ThreadReduceSrcDesc_M_K,
-                         ThreadReduceDstDesc_M,
-                         reduce::Max,
-                         false, // param ignored
-                         detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
- ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
- // const index_t thread_local_id = get_thread_local_1d_id();
- // printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
- using ThreadClusterLengths_M_K = Sequence<32, 2>;
- using ThreadClusterArrangeOrder = Sequence<1, 0>;
- using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
-     AccDataType,
-     BlockSize,
-     ThreadClusterLengths_M_K,
-     ThreadClusterArrangeOrder,
-     reduce::Max,
-     false, // param ignored
-     detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
- auto reduce_work_buf =
-     make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
- block_sync_lds();
- BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
- block_sync_lds();
- // printf("\n");
- // printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
- // softmax
- using BlockwiseSumReduce = PartitionedBlockwiseReduction<
-     AccDataType,
-     BlockSize,
-     ThreadClusterLengths_M_K,
-     ThreadClusterArrangeOrder,
-     reduce::Add,
-     false, // ignored
-     detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
- using ThreadwiseSumReduce =
-     ThreadwiseReduction<AccDataType,
-                         ThreadReduceSrcDesc_M_K,
-                         ThreadReduceDstDesc_M,
-                         reduce::Add,
-                         false, // ignored
-                         detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
- static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
-     xdlops_out.template AsType<float>()(iK) =
-         math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
- });
- ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
- block_sync_lds();
- BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
- block_sync_lds();
- static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
-     xdlops_out.template AsType<float>()(iK) =
-         xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
- });
+ float f_sum, f_max;
+ using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize,
+                                              FloatAcc,
+                                              MPerBlock,
+                                              MPerXDL,
+                                              NPerXDL,
+                                              blockwise_gemm.GetRegSizePerXdlops(),
+                                              MXdlPerWave,
+                                              NXdlPerWave>;
+ BlockwiseSoftmax::Run(c_thread_buf, f_sum, f_max, p_reduce_work_buffer);
}
// output: register to global memory
...