Commit 652728bc authored by ltqin's avatar ltqin
Browse files

regular code

parent c4b19b18
......@@ -13,6 +13,7 @@
namespace ck {
template <index_t BlockSize,
typename AccDataType,
index_t MPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t RegSizePerXdlops,
......@@ -22,12 +23,15 @@ struct BlockwiseSoftmax_V1
{
static_assert(MRepeat == 1, "Now MRepeat must equal 1");
static __shared__ AccDataType p_lex[MPerBlock];
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t MThreadSliceSize = 1;
static constexpr index_t WaveSize = 64;
static_assert(MPerBlock == MPerXDL * BlockSize / WaveSize, "wave is only m direction");
struct BlockToMKMap_M0_K_M1Adapt
{
__host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;
......@@ -57,7 +61,7 @@ struct BlockwiseSoftmax_V1
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using ThreadClusterLengths_M_K = Sequence<MPerXDL * BlockSize / WaveSize, WaveSize / MPerXDL>;
using ThreadClusterLengths_M_K = Sequence<MPerBlock, WaveSize / MPerXDL>;
using BlockwiseMaxReduce =
PartitionedBlockwiseReduction2<AccDataType,
......
......@@ -476,6 +476,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize,
FloatAcc,
MPerBlock,
MPerXDL,
NPerXDL,
blockwise_gemm.GetRegSizePerXdlops(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment