Commit 776721ab authored by Jing Zhang's avatar Jing Zhang
Browse files

tweak

parent 0e5848a4
...@@ -285,13 +285,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -285,13 +285,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
} }
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static CIndex __device__ static CIndex
CalculateCThreadOriginDataIndex(const index_t m0, const index_t n0, const index_t blk_i) CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{ {
const index_t waveId = get_thread_local_1d_id() / WaveSize; const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i); const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const index_t waveId_m = waveId / NWaves; const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves; const index_t waveId_n = waveId % NWaves;
......
...@@ -302,14 +302,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -302,14 +302,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Sequence<0>{}, Sequence<1, 2>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize, BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k_m0_m1_block_desc), decltype(a_k_m0_m1_block_desc),
decltype(b_k_n0_n1_block_desc), decltype(b_k_n0_n1_block_desc),
MPerWave, MPerWave,
NPerWave, NPerWave,
KPerWave>{}; KPerWave>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout(); constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize(); constexpr index_t BlkSize = CLayout.GetBlkSize();
......
...@@ -108,12 +108,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -108,12 +108,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmKPerWave = 4; constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 1; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment