Commit 730b98e1 authored by aska-0096's avatar aska-0096
Browse files

revert blkgemm pipe v2 changes.

parent d64030ed
...@@ -140,20 +140,15 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -140,20 +140,15 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride; using Base::AMmaKStride;
using Base::BMmaKStride; using Base::BMmaKStride;
// static constexpr index_t WgpPerCU = static constexpr index_t WgpPerCU =
// (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
static constexpr index_t RegPerFetch =
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock / BlockSize / 4;
static constexpr index_t MaximumPrefetchStage = (256 / RegPerFetch) > 8 ? 8
: (256 / RegPerFetch);
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
92 * 1024, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); 32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages = static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2 ? FullMemBandPrefetchStages <= MaximumPrefetchStage FullMemBandPrefetchStages >= 2
? FullMemBandPrefetchStages ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
: MaximumPrefetchStage : 2;
: 2;
static constexpr index_t PrefillStages = 1; static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages; static constexpr index_t GlobalBufferNum = PrefetchStages;
...@@ -635,10 +630,11 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -635,10 +630,11 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
// static constexpr index_t WgpPerCU = static constexpr index_t WgpPerCU =
// (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
92 * 1024, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); 32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages = static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2 FullMemBandPrefetchStages >= 2
? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment