Commit 652728bc authored by ltqin's avatar ltqin
Browse files

regular code

parent c4b19b18
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
namespace ck { namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename AccDataType, typename AccDataType,
index_t MPerBlock,
index_t MPerXDL, index_t MPerXDL,
index_t NPerXDL, index_t NPerXDL,
index_t RegSizePerXdlops, index_t RegSizePerXdlops,
...@@ -22,12 +23,15 @@ struct BlockwiseSoftmax_V1 ...@@ -22,12 +23,15 @@ struct BlockwiseSoftmax_V1
{ {
static_assert(MRepeat == 1, "Now MRepeat must equal 1"); static_assert(MRepeat == 1, "Now MRepeat must equal 1");
static __shared__ AccDataType p_lex[MPerBlock];
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t MThreadSliceSize = 1; static constexpr index_t MThreadSliceSize = 1;
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = 64;
static_assert(MPerBlock == MPerXDL * BlockSize / WaveSize, "wave is only m direction");
struct BlockToMKMap_M0_K_M1Adapt struct BlockToMKMap_M0_K_M1Adapt
{ {
__host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default; __host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;
...@@ -57,7 +61,7 @@ struct BlockwiseSoftmax_V1 ...@@ -57,7 +61,7 @@ struct BlockwiseSoftmax_V1
false, // param ignored false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using ThreadClusterLengths_M_K = Sequence<MPerXDL * BlockSize / WaveSize, WaveSize / MPerXDL>; using ThreadClusterLengths_M_K = Sequence<MPerBlock, WaveSize / MPerXDL>;
using BlockwiseMaxReduce = using BlockwiseMaxReduce =
PartitionedBlockwiseReduction2<AccDataType, PartitionedBlockwiseReduction2<AccDataType,
......
...@@ -476,6 +476,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -476,6 +476,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize, using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize,
FloatAcc, FloatAcc,
MPerBlock,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
blockwise_gemm.GetRegSizePerXdlops(), blockwise_gemm.GetRegSizePerXdlops(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment