Commit 6ea43353 authored by Adam Osewski's avatar Adam Osewski
Browse files

Fixes in pipeline.

parent 4f18c2de
......@@ -147,12 +147,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave>
{
template <typename BlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(BlockTile& block_tile,
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile_raw(block_tile, dram_tile_window);
// TODO: we need to have an api of load_tile which takes as param output tile
load_tile_raw(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
buffer_load_fence();
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
......@@ -216,6 +218,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
a_copy_dram_window.init_raw();
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
......@@ -228,6 +232,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
b_copy_dram_window.init_raw();
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
......@@ -283,7 +289,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [2, PrefetchStages]
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
......@@ -295,7 +301,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
index_t i = 0;
do
{
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
......@@ -330,10 +336,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds();
LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment