Commit 02f8c487 authored by carlushuang's avatar carlushuang

add single issue api

parent bafb600b
...@@ -21,28 +21,32 @@ template <typename BottomTensorView_, ...@@ -21,28 +21,32 @@ template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_, CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.load(bool_constant<oob_conditional_check>{}); return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_, CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.load(bool_constant<oob_conditional_check>{}); return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
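For orientation, a usage sketch of the extended load_tile signature (the window win is assumed context, not part of this diff): the default i_access = -1 issues every access, while number<k>{} issues only the k-th one.

```cpp
// Sketch only: win is assumed to be a tile window with static distribution
// living in device code; all names here are illustrative.
auto full      = load_tile(win);                        // i_access = -1: issue all accesses
auto partial   = load_tile(win, number<0>{});           // issue only access #0
auto unchecked = load_tile(win, number<1>{},
                           bool_constant<false>{});     // access #1, skip the OOB check
```

Note that each call constructs its own destination tensor, so single-issue loads are mostly useful through the raw overloads below, which write into a caller-owned tensor.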
template <typename T, template <typename T,
...@@ -50,6 +54,7 @@ template <typename T, ...@@ -50,6 +54,7 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
...@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, ...@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); tile_window.load_raw(
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
template <typename T, template <typename T,
...@@ -68,6 +75,7 @@ template <typename T, ...@@ -68,6 +75,7 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
...@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, ...@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); tile_window.load_raw(
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
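The raw overloads make the single-issue form useful for software pipelining: the caller owns the destination tensor and can interleave each memory issue with independent work. A minimal sketch, assuming win, DataType, and TileDstr from the surrounding kernel:

```cpp
// Sketch: issue the window's accesses one at a time into a caller-owned tensor.
auto t = make_static_distributed_tensor<DataType>(TileDstr{});
static_for<0, win.get_num_access(), 1>{}([&](auto ia) {
    load_tile_raw(t, win, ia);  // issue only the ia-th access
    // ... independent ALU work or other memory issues can be placed here ...
});
```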
// for this API we require the user to pass smem declared with the CK_TILE_LDS_ADDR attribute // for this API we require the user to pass smem declared with the CK_TILE_LDS_ADDR attribute
...@@ -89,6 +99,7 @@ template <typename LdsTileWindow_, ...@@ -89,6 +99,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
async_load_tile(LdsTileWindow_&& lds_tile, async_load_tile(LdsTileWindow_&& lds_tile,
...@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile, ...@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{}); return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
...@@ -106,15 +119,18 @@ template <typename LdsTileWindow_, ...@@ -106,15 +119,18 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_, const tile_window_linear<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{}); return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
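The async overloads follow the same pattern; a hedged sketch (lds_win and gwin are assumed context, created elsewhere):

```cpp
// Sketch: lds_win is an LDS tile window over CK_TILE_LDS_ADDR-attributed smem,
// gwin a global-memory tile window; both are illustrative names.
async_load_tile(lds_win, gwin);               // issue every access
async_load_tile(lds_win, gwin, number<0>{});  // issue only the first access
```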
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
...@@ -122,6 +138,7 @@ template <typename LdsTileWindow_, ...@@ -122,6 +138,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
...@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile, ...@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
return tile_window.async_load_raw( return tile_window.async_load_raw(lds_tile,
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
...@@ -142,6 +162,7 @@ template <typename LdsTileWindow_, ...@@ -142,6 +162,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
...@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, ...@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
return tile_window.async_load_raw( return tile_window.async_load_raw(lds_tile,
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
} }
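The raw async variant additionally threads through the pre_nop flag, which the implementation applies only to the very first issue (see the pre_nop_ lambda further down in this diff). A sketch with the same assumed windows:

```cpp
// Sketch: explicit tags for all three trailing parameters.
async_load_tile_raw(lds_win, gwin,
                    number<-1>{},            // negative: issue every access
                    bool_constant<true>{},   // oob_conditional_check
                    bool_constant<true>{});  // pre_nop on the first issue
```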
template <typename WindowLengths> template <typename WindowLengths, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&) CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{ {
return null_tensor{}; return null_tensor{};
} }
template <typename T, typename WindowLengths> template <typename T,
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&) typename WindowLengths,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/,
const null_tile_window<WindowLengths>&,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{ {
} }
// TODO: this function requires that certain sub-fields exist on the target tile window // TODO: this function requires that certain sub-fields exist on the target tile window
template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false> template <typename TileWindow,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w, CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
...@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w, ...@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
auto t = make_static_distributed_tensor<DataType>(TileDstr{}); auto t = make_static_distributed_tensor<DataType>(TileDstr{});
load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); load_tile_raw(t, w, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return t; return t;
} }
......
...@@ -18,10 +18,12 @@ namespace ck_tile { ...@@ -18,10 +18,12 @@ namespace ck_tile {
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp, store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>; using TileDstr = remove_cvref_t<TileDistribution_>;
...@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t ...@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
tile_window_tmp.get_window_origin(), tile_window_tmp.get_window_origin(),
tile_dstr); tile_dstr);
tile_window.store(dstr_tensor); tile_window.store(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp, store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>; using TileDstr = remove_cvref_t<TileDistribution_>;
...@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_ ...@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
tile_window_tmp.get_window_origin(), tile_window_tmp.get_window_origin(),
tile_dstr); tile_dstr);
tile_window.store_raw(dstr_tensor); tile_window.store_raw(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_, store_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.store(dstr_tensor); tile_window.store(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_, store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.store_raw(dstr_tensor); tile_window.store_raw(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void store_tile( CK_TILE_DEVICE void store_tile(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>& tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window, tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.store(dstr_tensor); tile_window.store(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void store_tile_raw( CK_TILE_DEVICE void store_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>& tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window, tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.store_raw(dstr_tensor); tile_window.store_raw(dstr_tensor, number<i_access>{});
} }
} // namespace ck_tile } // namespace ck_tile
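Stores gain the same access-index parameter; a usage sketch with illustrative names (out_win and acc_tensor are assumptions, not from this diff):

```cpp
// Sketch: the default issues all accesses, number<k>{} issues exactly one.
store_tile(out_win, acc_tensor);               // all accesses
store_tile(out_win, acc_tensor, number<2>{});  // only access #2
```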
...@@ -18,6 +18,23 @@ ...@@ -18,6 +18,23 @@
namespace ck_tile { namespace ck_tile {
// TODO: is NumCoord no longer needed?
#define WINDOW_DISPATCH_ISSUE_2() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumCoord, 1>{}([&](auto iCoord) { \
static_for<0, NumAccessPerCoord, 1>{}( \
[&](auto iCoordAccess) { issue(iCoord, iCoordAccess); }); \
}); \
} \
else \
{ \
static_assert(i_access < (NumCoord * NumAccessPerCoord)); \
constexpr auto iCoordAccess = number<i_access % NumAccessPerCoord>{}; \
constexpr auto iCoord = number<i_access / NumAccessPerCoord>{}; \
issue(iCoord, iCoordAccess); \
}
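The macro flattens the (iCoord, iCoordAccess) pair into one index, with iCoord = i_access / NumAccessPerCoord and iCoordAccess = i_access % NumAccessPerCoord. An illustrative check with made-up sizes:

```cpp
// Illustration only: how a flat i_access maps onto (iCoord, iCoordAccess);
// index_t is ck_tile's index type, sizes are chosen purely for the example.
constexpr index_t NumCoord = 2, NumAccessPerCoord = 4;
constexpr index_t i_access = 5;
static_assert(i_access < NumCoord * NumAccessPerCoord);
static_assert(i_access / NumAccessPerCoord == 1); // iCoord
static_assert(i_access % NumAccessPerCoord == 1); // iCoordAccess
```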
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename StaticTileDistribution_, typename StaticTileDistribution_,
...@@ -283,8 +300,8 @@ struct tile_window_with_static_distribution ...@@ -283,8 +300,8 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -296,12 +313,11 @@ struct tile_window_with_static_distribution ...@@ -296,12 +313,11 @@ struct tile_window_with_static_distribution
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr); auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (to be captured later) if compiled in C++20 /// TODO: use structured bindings (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...] // data index [y0, y1, ...]
...@@ -316,20 +332,17 @@ struct tile_window_with_static_distribution ...@@ -316,20 +332,17 @@ struct tile_window_with_static_distribution
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array( constexpr auto idx_ys = generate_array(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
: idx_ys_start[jj];
}, },
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() = dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j]; vec_value.template get_as<DataType>()[j];
}); });
#else #else
constexpr index_t d = constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0); static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()( dst_tensor.get_thread_buffer().template get_as<vector_t>()(
...@@ -347,14 +360,19 @@ struct tile_window_with_static_distribution ...@@ -347,14 +360,19 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
} }
}); };
});
WINDOW_DISPATCH_ISSUE_2();
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false> template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const bool_constant<pre_nop> = {}) const
{ {
...@@ -377,12 +395,11 @@ struct tile_window_with_static_distribution ...@@ -377,12 +395,11 @@ struct tile_window_with_static_distribution
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer()); auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (to be captured later) if compiled in C++20 /// TODO: use structured bindings (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() { constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
...@@ -393,8 +410,7 @@ struct tile_window_with_static_distribution ...@@ -393,8 +410,7 @@ struct tile_window_with_static_distribution
// data index [y0, y1, ...] // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr index_t d = constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0); static_assert(d % Traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>( get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
...@@ -405,8 +421,7 @@ struct tile_window_with_static_distribution ...@@ -405,8 +421,7 @@ struct tile_window_with_static_distribution
pre_nop_); pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \ #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile( asm volatile(""); // this also appears starting from rocm-6.2 with the same symptom, so reuse this flag
""); // this also appears starting from rocm-6.2 with the same symptom, so reuse this flag
#endif #endif
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -419,17 +434,18 @@ struct tile_window_with_static_distribution ...@@ -419,17 +434,18 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
} }
}); };
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE WINDOW_DISPATCH_ISSUE_2();
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
} }
// TODO: currently async load is only implemented in inline asm // TODO: currently async load is only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false> template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const bool_constant<pre_nop> = {}) const
{ {
...@@ -470,12 +486,11 @@ struct tile_window_with_static_distribution ...@@ -470,12 +486,11 @@ struct tile_window_with_static_distribution
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto issue = [&](auto iCoord, auto iCoordAccess) {
// TODO: use structured bindings (to be captured later) if compiled in C++20 // TODO: use structured bindings (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() { constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
...@@ -501,12 +516,14 @@ struct tile_window_with_static_distribution ...@@ -501,12 +516,14 @@ struct tile_window_with_static_distribution
m0_inc_with_memory(size_per_issue); m0_inc_with_memory(size_per_issue);
} }
}); };
});
WINDOW_DISPATCH_ISSUE_2();
} }
template <typename LdsTileWindow_, bool oob_conditional_check = true> template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>; using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
...@@ -544,12 +561,11 @@ struct tile_window_with_static_distribution ...@@ -544,12 +561,11 @@ struct tile_window_with_static_distribution
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto issue = [&](auto iCoord, auto iCoordAccess) {
// TODO: use structured bindings (to be captured later) if compiled in C++20 // TODO: use structured bindings (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor // read from bottom tensor
...@@ -569,12 +585,13 @@ struct tile_window_with_static_distribution ...@@ -569,12 +585,13 @@ struct tile_window_with_static_distribution
smem += size_per_issue; // Note we manually advance smem by the per-issue offset smem += size_per_issue; // Note we manually advance smem by the per-issue offset
} }
}); };
}); WINDOW_DISPATCH_ISSUE_2();
} }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -586,12 +603,11 @@ struct tile_window_with_static_distribution ...@@ -586,12 +603,11 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (to be captured later) if compiled in C++20 /// TODO: use structured bindings (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...] // data index [y0, y1, ...]
...@@ -604,13 +620,11 @@ struct tile_window_with_static_distribution ...@@ -604,13 +620,11 @@ struct tile_window_with_static_distribution
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array( constexpr auto idx_ys = generate_array(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
: idx_ys_start[jj];
}, },
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) = vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
...@@ -620,10 +634,7 @@ struct tile_window_with_static_distribution ...@@ -620,10 +634,7 @@ struct tile_window_with_static_distribution
// write into bottom tensor // write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>( get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{});
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -636,12 +647,13 @@ struct tile_window_with_static_distribution ...@@ -636,12 +647,13 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
} }
}); };
}); WINDOW_DISPATCH_ISSUE_2();
} }
CK_TILE_DEVICE void template <index_t i_access = -1>
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -652,12 +664,11 @@ struct tile_window_with_static_distribution ...@@ -652,12 +664,11 @@ struct tile_window_with_static_distribution
static constexpr bool oob_conditional_check = true; static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (to be captured later) if compiled in C++20 /// TODO: use structured bindings (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...] // data index [y0, y1, ...]
...@@ -668,12 +679,10 @@ struct tile_window_with_static_distribution ...@@ -668,12 +679,10 @@ struct tile_window_with_static_distribution
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array( constexpr auto idx_ys = generate_array(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
: idx_ys_start[jj];
}, },
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) = vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
}); });
...@@ -694,12 +703,14 @@ struct tile_window_with_static_distribution ...@@ -694,12 +703,14 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
} }
}); };
});
WINDOW_DISPATCH_ISSUE_2();
} }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -710,12 +721,11 @@ struct tile_window_with_static_distribution ...@@ -710,12 +721,11 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (to be captured later) if compiled in C++20 /// TODO: use structured bindings (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...] // data index [y0, y1, ...]
...@@ -727,13 +737,11 @@ struct tile_window_with_static_distribution ...@@ -727,13 +737,11 @@ struct tile_window_with_static_distribution
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array( constexpr auto idx_ys = generate_array(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
: idx_ys_start[jj];
}, },
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) = vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
...@@ -741,10 +749,7 @@ struct tile_window_with_static_distribution ...@@ -741,10 +749,7 @@ struct tile_window_with_static_distribution
// write into bottom tensor // write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>( get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{});
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -757,8 +762,9 @@ struct tile_window_with_static_distribution ...@@ -757,8 +762,9 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
} }
}); };
});
WINDOW_DISPATCH_ISSUE_2();
} }
// move thread's bottom tensor coordinate // move thread's bottom tensor coordinate
...@@ -857,6 +863,8 @@ struct tile_window_with_static_distribution ...@@ -857,6 +863,8 @@ struct tile_window_with_static_distribution
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_; array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
}; };
#undef WINDOW_DISPATCH_ISSUE_2
// TODO: use strategy // TODO: use strategy
template <typename TensorView_, template <typename TensorView_,
typename WindowLengths_, typename WindowLengths_,
......
...@@ -18,6 +18,17 @@ ...@@ -18,6 +18,17 @@
namespace ck_tile { namespace ck_tile {
#define WINDOW_DISPATCH_ISSUE() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \
} \
else \
{ \
static_assert(i_access < NumAccess); \
issue(number<i_access>{}); \
}
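For the linear window the dispatch is one-dimensional. A caller can exploit it to split one operation around unrelated work; a sketch assuming an even access count and the illustrative names win and tensor:

```cpp
// Sketch: issue the first half of the stores, do other work, then the rest.
constexpr index_t num_access = win.get_num_access();
static_for<0, num_access / 2, 1>{}([&](auto ia) { win.store(tensor, ia); });
// ... independent work ...
static_for<num_access / 2, num_access, 1>{}([&](auto ia) { win.store(tensor, ia); });
```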
// //
// This version of the tile window pre-caches offsets/flags as needed // This version of the tile window pre-caches offsets/flags as needed
// //
...@@ -443,8 +454,8 @@ struct tile_window_linear ...@@ -443,8 +454,8 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{ {
using vector_t = typename traits::vector_t; using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys; using SFC_Ys = typename traits::SFC_Ys;
...@@ -453,9 +464,8 @@ struct tile_window_linear ...@@ -453,9 +464,8 @@ struct tile_window_linear
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr); auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) {
static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto IAccess = number<i_access_>{};
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
...@@ -494,17 +504,22 @@ struct tile_window_linear ...@@ -494,17 +504,22 @@ struct tile_window_linear
dst_tensor.get_thread_buffer().template get_as<vector_t>()( dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value); number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif #endif
}); };
WINDOW_DISPATCH_ISSUE();
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false> template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access> = {}, // a negative value loops over all accesses
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const bool_constant<pre_nop> = {}) const
{ {
using vector_t = typename traits::vector_t; using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys; using SFC_Ys = typename traits::SFC_Ys;
static constexpr index_t YElementSize = static constexpr index_t YElementSize =
...@@ -516,11 +531,10 @@ struct tile_window_linear ...@@ -516,11 +531,10 @@ struct tile_window_linear
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer()); auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) {
static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto IAccess = number<i_access_>{};
constexpr auto IAccess = number<i_access>{};
constexpr auto pre_nop_ = [&]() { constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0 && if constexpr(pre_nop && i_access_ == 0 &&
BottomTensorView::buffer_view::get_address_space() == BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global) address_space_enum::global)
return bool_constant<true>{}; return bool_constant<true>{};
...@@ -550,16 +564,18 @@ struct tile_window_linear ...@@ -550,16 +564,18 @@ struct tile_window_linear
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif #endif
}); };
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much " WINDOW_DISPATCH_ISSUE();
"scratch memory" ::);
#endif
} }
// TODO: currently async load is only implemented in inline asm // TODO: currently async load is only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false> template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const bool_constant<pre_nop> = {}) const
{ {
...@@ -600,10 +616,10 @@ struct tile_window_linear ...@@ -600,10 +616,10 @@ struct tile_window_linear
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() { constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0) if constexpr(pre_nop && i_access_ == 0)
return bool_constant<true>{}; return bool_constant<true>{};
else else
return bool_constant<false>{}; return bool_constant<false>{};
...@@ -618,15 +634,18 @@ struct tile_window_linear ...@@ -618,15 +634,18 @@ struct tile_window_linear
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_); smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
// move thread coordinate // move thread coordinate
if constexpr(i_access != (NumAccess - 1)) if constexpr(i_access_ != (NumAccess - 1))
{ {
m0_inc_with_memory(size_per_issue); m0_inc_with_memory(size_per_issue);
} }
}); };
WINDOW_DISPATCH_ISSUE();
} }
template <typename LdsTileWindow_, bool oob_conditional_check = true> template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>; using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
...@@ -667,8 +686,8 @@ struct tile_window_linear ...@@ -667,8 +686,8 @@ struct tile_window_linear
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess]; auto bottom_tensor_flag = cached_flags_[IAccess];
...@@ -682,15 +701,18 @@ struct tile_window_linear ...@@ -682,15 +701,18 @@ struct tile_window_linear
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(i_access != (NumAccess - 1)) if constexpr(i_access_ != (NumAccess - 1))
{ {
smem += size_per_issue; // Note we manually advance smem by the per-issue offset smem += size_per_issue; // Note we manually advance smem by the per-issue offset
} }
}); };
WINDOW_DISPATCH_ISSUE();
} }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
...@@ -700,8 +722,8 @@ struct tile_window_linear ...@@ -700,8 +722,8 @@ struct tile_window_linear
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess); constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
...@@ -732,13 +754,15 @@ struct tile_window_linear ...@@ -732,13 +754,15 @@ struct tile_window_linear
bottom_tensor_flag, bottom_tensor_flag,
vec_value, vec_value,
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{});
}); };
WINDOW_DISPATCH_ISSUE();
} }
CK_TILE_DEVICE void template <index_t i_access = -1>
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const
{ {
using vector_t = typename traits::vector_t; using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys; using SFC_Ys = typename traits::SFC_Ys;
...@@ -746,8 +770,8 @@ struct tile_window_linear ...@@ -746,8 +770,8 @@ struct tile_window_linear
static constexpr bool oob_conditional_check = true; static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess); constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
...@@ -773,11 +797,14 @@ struct tile_window_linear ...@@ -773,11 +797,14 @@ struct tile_window_linear
get_bottom_tensor_view() get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>( .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value); bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
}); };
WINDOW_DISPATCH_ISSUE();
} }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
...@@ -787,8 +814,8 @@ struct tile_window_linear ...@@ -787,8 +814,8 @@ struct tile_window_linear
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess); constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
...@@ -820,7 +847,9 @@ struct tile_window_linear ...@@ -820,7 +847,9 @@ struct tile_window_linear
bottom_tensor_flag, bottom_tensor_flag,
vec_value, vec_value,
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{});
}); };
WINDOW_DISPATCH_ISSUE();
} }
// move thread's bottom tensor coordinate // move thread's bottom tensor coordinate
...@@ -920,6 +949,8 @@ struct tile_window_linear ...@@ -920,6 +949,8 @@ struct tile_window_linear
array<bool, traits::NumAccess> cached_flags_; array<bool, traits::NumAccess> cached_flags_;
}; };
#undef WINDOW_DISPATCH_ISSUE
namespace impl { namespace impl {
template <address_space_enum, index_t len_> template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl struct default_linear_bottom_dims_impl
......
...@@ -17,10 +17,12 @@ namespace ck_tile { ...@@ -17,10 +17,12 @@ namespace ck_tile {
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp, update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>; using TileDstr = remove_cvref_t<TileDistribution_>;
...@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& ...@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>&
tile_window_tmp.get_window_origin(), tile_window_tmp.get_window_origin(),
tile_dstr); tile_dstr);
tile_window.update(dstr_tensor); tile_window.update(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_, update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.update(dstr_tensor); tile_window.update(dstr_tensor, number<i_access>{});
} }
} // namespace ck_tile } // namespace ck_tile
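update_tile forwards the access index the same way; a sketch with illustrative names (out_win, partial_tensor are assumptions):

```cpp
// Sketch: update an output window, optionally one access at a time.
update_tile(out_win, partial_tensor);               // all accesses
update_tile(out_win, partial_tensor, number<0>{});  // only the first access
```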
...@@ -33,8 +33,8 @@ ...@@ -33,8 +33,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp" // #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp" // #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
......
...@@ -205,7 +205,7 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -205,7 +205,7 @@ struct BlockFmhaPipelineQRAsyncEx
"wrong!"); "wrong!");
// K tile in LDS // K tile in LDS
auto k_lds_store = [&](){ auto k_lds_store = [&]() {
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr); auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return generate_tuple( return generate_tuple(
[&](auto i_buf) { [&](auto i_buf) {
...@@ -218,17 +218,16 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -218,17 +218,16 @@ struct BlockFmhaPipelineQRAsyncEx
number<Policy::NumPrefetchK>{}); number<Policy::NumPrefetchK>{});
}(); }();
auto k_lds_load = [&](){ auto k_lds_load = [&]() {
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr); auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return make_tile_window( return make_tile_window(make_tensor_view<address_space_enum::lds>(
make_tensor_view<address_space_enum::lds>( k_lds_ptr, Policy::template MakeSmemLoadDesc_K<Problem>()),
k_lds_ptr, Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(),
Policy::template MakeSmemLoadDesc_K<Problem>()), {0, 0});
Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(), {0, 0});
}(); }();
// V tile in LDS // V tile in LDS
auto v_lds_store = [&](){ auto v_lds_store = [&]() {
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr); auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return generate_tuple( return generate_tuple(
[&](auto i_buf) { [&](auto i_buf) {
...@@ -241,13 +240,12 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -241,13 +240,12 @@ struct BlockFmhaPipelineQRAsyncEx
number<Policy::NumPrefetchV>{}); number<Policy::NumPrefetchV>{});
}(); }();
auto v_lds_load = [&](){ auto v_lds_load = [&]() {
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr); auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return make_tile_window( return make_tile_window(make_tensor_view<address_space_enum::lds>(
make_tensor_view<address_space_enum::lds>( v_lds_ptr, Policy::template MakeSmemLoadDesc_V<Problem>()),
v_lds_ptr, Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(),
Policy::template MakeSmemLoadDesc_V<Problem>()), {0, 0});
Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(), {0, 0});
}(); }();
// reduction function for softmax // reduction function for softmax
...@@ -258,16 +256,14 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -258,16 +256,14 @@ struct BlockFmhaPipelineQRAsyncEx
constexpr auto warp_gemm_0 = Policy::template GetWarpGemm_0<Problem>(); constexpr auto warp_gemm_0 = Policy::template GetWarpGemm_0<Problem>();
constexpr auto warp_gemm_1 = Policy::template GetWarpGemm_1<Problem>(); constexpr auto warp_gemm_1 = Policy::template GetWarpGemm_1<Problem>();
auto gemm_0 = [&](){ auto gemm_0 = [&]() {
constexpr index_t total_repeats = Repeat_M0 * Repeat_N0 * Repeat_K0; constexpr index_t total_repeats = Repeat_M0 * Repeat_N0 * Repeat_K0;
// n*k*m, more relaxed ds_read // n*k*m, more relaxed ds_read
static_for<0, total_repeats, 1>{}( static_for<0, total_repeats, 1>{}([&](auto i_r) {
[&](auto i_r){
constexpr index_t i_m = i_r % Repeat_M0; constexpr index_t i_m = i_r % Repeat_M0;
constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0; constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0;
constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0); constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0);
} });
);
}; };
auto q_dram_window = make_tile_window_raw(q_dram_block_window_tmp.get_bottom_tensor_view(), auto q_dram_window = make_tile_window_raw(q_dram_block_window_tmp.get_bottom_tensor_view(),
...@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_0<Problem>()); using SaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_0<Problem>());
auto s_accs = generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{}); auto s_accs =
generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{});
// infer Sacc, S, P, M, L, Oacc type // infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs)); using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
...@@ -296,7 +293,8 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -296,7 +293,8 @@ struct BlockFmhaPipelineQRAsyncEx
using OaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_1<Problem>()); using OaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_1<Problem>());
// init Oacc, M, L // init Oacc, M, L
auto o_accs = generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{}); auto o_accs =
generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{});
auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{}); auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{}); auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
......
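The gemm_0 loop in the hunk above flattens the (m, k, n) repeat space into one linear index and peels it back with i_m cycling fastest, then i_k, then i_n, which is what the "n*k*m, more relaxed ds_read" comment refers to: consecutive iterations stay within one n-slice of K in LDS. A small host-side check of the same decomposition (plain C++, illustrative repeat counts):

    #include <cassert>

    int main()
    {
        constexpr int Repeat_M0 = 2, Repeat_K0 = 2, Repeat_N0 = 4; // illustrative
        for(int i_r = 0; i_r < Repeat_M0 * Repeat_N0 * Repeat_K0; ++i_r)
        {
            const int i_m = i_r % Repeat_M0;
            const int i_k = (i_r / Repeat_M0) % Repeat_K0;
            const int i_n = i_r / (Repeat_M0 * Repeat_K0);
            // i_m varies fastest, i_k next, i_n slowest -> n*k*m order
            assert(i_m < Repeat_M0 && i_k < Repeat_K0 && i_n < Repeat_N0);
        }
        return 0;
    }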
...@@ -105,14 +105,14 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -105,14 +105,14 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_0() CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_0()
{ {
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_0())::CWarpDstrEncoding; using AccWarpDescEnc_ = typename decltype(GetWarpGemm_0())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>; using BlockTile_ =
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0, Problem::BlockFmhaShape::BlockWarps_N0>; sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>; using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0,
constexpr auto enc = make_block_gemm_acc_enc< Problem::BlockFmhaShape::BlockWarps_N0>;
AccWarpDescEnc_, using WarpTile_ =
BlockTile_, sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>;
BlockWarps_, constexpr auto enc =
WarpTile_>(); make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc); constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr); auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr);
return t; return t;
...@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx
{ {
if constexpr(Problem::kHasDropout) if constexpr(Problem::kHasDropout)
{ {
constexpr index_t kMPerStep = Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0; constexpr index_t kMPerStep =
constexpr index_t kNPerStep = Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0; Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0;
constexpr index_t kNPerStep =
Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0;
return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t); return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
} }
...@@ -612,14 +614,14 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -612,14 +614,14 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_1() CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_1()
{ {
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_1())::CWarpDstrEncoding; using AccWarpDescEnc_ = typename decltype(GetWarpGemm_1())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>; using BlockTile_ =
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1, Problem::BlockFmhaShape::BlockWarps_N1>; sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>; using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1,
constexpr auto enc = make_block_gemm_acc_enc< Problem::BlockFmhaShape::BlockWarps_N1>;
AccWarpDescEnc_, using WarpTile_ =
BlockTile_, sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>;
BlockWarps_, constexpr auto enc =
WarpTile_>(); make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc); constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr); auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr);
return t; return t;
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace ck_tile { namespace ck_tile {
template<typename AccWarpDescEnc, template <typename AccWarpDescEnc,
typename BlockTile, // seq<M, N> typename BlockTile, // seq<M, N>
typename BlockWarps, typename BlockWarps,
typename WarpTile> typename WarpTile>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline ...@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline
{ {
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW #if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
auto x = load_tile_raw(inp_win, bool_constant<true>{}, bool_constant<true>{}); auto x = load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
buffer_load_fence(number<0>{}); buffer_load_fence(number<0>{});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
#else #else
......
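Note the call-site change in the hunk above: because load_tile_raw gained the i_access parameter ahead of the two bool_constant flags, callers that pass the flags positionally must now pass number<-1>{} explicitly to keep the issue-all behavior. Before/after, taken from the diff itself:

    // before: flags only
    auto x = load_tile_raw(inp_win, bool_constant<true>{}, bool_constant<true>{});
    // after: i_access comes first; -1 preserves the issue-everything path
    auto x = load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});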