Commit 9d709a68 authored by Adam Osewski's avatar Adam Osewski
Browse files

Add load tile overload which accepts output tensor as parameter.

* This give 8% perf boost at the cost of using more registers.
parent 93c30d2c
...@@ -31,6 +31,22 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT ...@@ -31,6 +31,22 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
return tile_window.load(bool_constant<oob_conditional_check>{}); return tile_window.load(bool_constant<oob_conditional_check>{});
} }
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
}
/** /**
* @brief Loads a tile of data using inline assembly. * @brief Loads a tile of data using inline assembly.
* *
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -284,15 +284,22 @@ struct tile_window_with_static_distribution ...@@ -284,15 +284,22 @@ struct tile_window_with_static_distribution
template <bool oob_conditional_check = true> template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename DistributedTensor, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t; using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys; using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) { static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20 /// TODO: use structure binding (to be captured later) if compiled in C++20
...@@ -346,8 +353,6 @@ struct tile_window_with_static_distribution ...@@ -346,8 +353,6 @@ struct tile_window_with_static_distribution
} }
}); });
}); });
return dst_tensor;
} }
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false> template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
......
...@@ -152,10 +152,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -152,10 +152,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const SrcTileWindow& dram_tile_window) const
{ {
// TODO: we need to have an api of load_tile which takes as param output tile load_tile(dst_block_tile, dram_tile_window);
load_tile_raw(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock}); move_tile_window(dram_tile_window, {0, KPerBlock});
buffer_load_fence();
} }
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction> template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
...@@ -220,7 +218,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -220,7 +218,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(), a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>()); Policy::template MakeADramTileDistribution<Problem>());
a_copy_dram_window.init_raw();
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = auto a_copy_lds_window =
...@@ -234,7 +231,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -234,7 +231,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(), b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>()); Policy::template MakeBDramTileDistribution<Problem>());
b_copy_dram_window.init_raw();
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window = auto b_copy_lds_window =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment