Commit 9d709a68 authored by Adam Osewski's avatar Adam Osewski
Browse files

Add load tile overload which accepts output tensor as parameter.

* This give 8% perf boost at the cost of using more registers.
parent 93c30d2c
......@@ -31,6 +31,22 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
return tile_window.load(bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
}
/**
* @brief Loads a tile of data using inline assembly.
*
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -284,15 +284,22 @@ struct tile_window_with_static_distribution
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename DistributedTensor, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
......@@ -346,8 +353,6 @@ struct tile_window_with_static_distribution
}
});
});
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
......
......@@ -152,10 +152,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
// TODO: we need to have an api of load_tile which takes as param output tile
load_tile_raw(dst_block_tile, dram_tile_window);
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
buffer_load_fence();
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
......@@ -220,7 +218,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
a_copy_dram_window.init_raw();
// A LDS tile window for store
auto a_copy_lds_window =
......@@ -234,7 +231,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
b_copy_dram_window.init_raw();
// B LDS tile window for store
auto b_copy_lds_window =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment