Commit 6ea43353 authored by Adam Osewski
Browse files

Fixes in pipeline.

parent 4f18c2de
...@@ -147,12 +147,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -147,12 +147,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
template <> template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> struct PipelineImpl<GemmPipelineScheduler::Intrawave>
{ {
template <typename BlockTile, typename SrcTileWindow> template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(BlockTile& block_tile, CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const SrcTileWindow& dram_tile_window) const
{ {
load_tile_raw(block_tile, dram_tile_window); // TODO: we need to have an api of load_tile which takes as param output tile
load_tile_raw(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock}); move_tile_window(dram_tile_window, {0, KPerBlock});
buffer_load_fence();
} }
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction> template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
...@@ -216,6 +218,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -216,6 +218,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(), a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>()); Policy::template MakeADramTileDistribution<Problem>());
a_copy_dram_window.init_raw();
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = auto a_copy_lds_window =
make_tile_window(a_lds_block, make_tile_window(a_lds_block,
...@@ -228,6 +232,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -228,6 +232,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(), b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>()); Policy::template MakeBDramTileDistribution<Problem>());
b_copy_dram_window.init_raw();
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window = auto b_copy_lds_window =
make_tile_window(b_lds_block, make_tile_window(b_lds_block,
...@@ -283,7 +289,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -283,7 +289,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [2, PrefetchStages] // Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window); GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window); GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
...@@ -295,7 +301,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -295,7 +301,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
index_t i = 0; index_t i = 0;
do do
{ {
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); // block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
...@@ -330,10 +336,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -330,10 +336,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds(); block_sync_lds();
LocalPrefill(a_copy_lds_window, LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_block_tiles.get(number<prefetch_idx>{}),
a_element_func); a_element_func);
LocalPrefill(b_copy_lds_window, LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), b_block_tiles.get(number<prefetch_idx>{}),
b_element_func); b_element_func);
}); });
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment