Commit 02f8c487 authored by carlushuang's avatar carlushuang
Browse files

add single issue api

parent bafb600b
...@@ -21,28 +21,32 @@ template <typename BottomTensorView_, ...@@ -21,28 +21,32 @@ template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_, CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.load(bool_constant<oob_conditional_check>{}); return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_, CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.load(bool_constant<oob_conditional_check>{}); return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename T, template <typename T,
...@@ -50,6 +54,7 @@ template <typename T, ...@@ -50,6 +54,7 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
...@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, ...@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); tile_window.load_raw(
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
template <typename T, template <typename T,
...@@ -68,6 +75,7 @@ template <typename T, ...@@ -68,6 +75,7 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
...@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, ...@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); tile_window.load_raw(
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem // for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
...@@ -89,6 +99,7 @@ template <typename LdsTileWindow_, ...@@ -89,6 +99,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
async_load_tile(LdsTileWindow_&& lds_tile, async_load_tile(LdsTileWindow_&& lds_tile,
...@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile, ...@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{}); return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
...@@ -106,15 +119,18 @@ template <typename LdsTileWindow_, ...@@ -106,15 +119,18 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true> bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_, const tile_window_linear<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {})
{ {
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{}); return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
...@@ -122,6 +138,7 @@ template <typename LdsTileWindow_, ...@@ -122,6 +138,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
...@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile, ...@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
return tile_window.async_load_raw( return tile_window.async_load_raw(lds_tile,
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
...@@ -142,6 +162,7 @@ template <typename LdsTileWindow_, ...@@ -142,6 +162,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
...@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, ...@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
LinearBottomDims_>& tile_window, LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
return tile_window.async_load_raw( return tile_window.async_load_raw(lds_tile,
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
} }
template <typename WindowLengths> template <typename WindowLengths, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&) CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{ {
return null_tensor{}; return null_tensor{};
} }
template <typename T, typename WindowLengths> template <typename T,
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&) typename WindowLengths,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/,
const null_tile_window<WindowLengths>&,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{ {
} }
// TODO: this function requires some sub-fileds exist for the target tile window // TODO: this function requires some sub-fileds exist for the target tile window
template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false> template <typename TileWindow,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w, CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
...@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w, ...@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
auto t = make_static_distributed_tensor<DataType>(TileDstr{}); auto t = make_static_distributed_tensor<DataType>(TileDstr{});
load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); load_tile_raw(t, w, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return t; return t;
} }
......
...@@ -18,10 +18,12 @@ namespace ck_tile { ...@@ -18,10 +18,12 @@ namespace ck_tile {
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp, store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>; using TileDstr = remove_cvref_t<TileDistribution_>;
...@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t ...@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
tile_window_tmp.get_window_origin(), tile_window_tmp.get_window_origin(),
tile_dstr); tile_dstr);
tile_window.store(dstr_tensor); tile_window.store(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp, store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>; using TileDstr = remove_cvref_t<TileDistribution_>;
...@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_ ...@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
tile_window_tmp.get_window_origin(), tile_window_tmp.get_window_origin(),
tile_dstr); tile_dstr);
tile_window.store_raw(dstr_tensor); tile_window.store_raw(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_, store_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.store(dstr_tensor); tile_window.store(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_, store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.store_raw(dstr_tensor); tile_window.store_raw(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void store_tile( CK_TILE_DEVICE void store_tile(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>& tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window, tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.store(dstr_tensor); tile_window.store(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename LinearBottomDims_, typename LinearBottomDims_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void store_tile_raw( CK_TILE_DEVICE void store_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>& tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window, tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.store_raw(dstr_tensor); tile_window.store_raw(dstr_tensor, number<i_access>{});
} }
} // namespace ck_tile } // namespace ck_tile
This diff is collapsed.
...@@ -18,6 +18,17 @@ ...@@ -18,6 +18,17 @@
namespace ck_tile { namespace ck_tile {
#define WINDOW_DISPATCH_ISSUE() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \
} \
else \
{ \
static_assert(i_access < NumAccess); \
issue(number<i_access>{}); \
}
// //
// This version of tile window will pre-cache offset/flags based on need // This version of tile window will pre-cache offset/flags based on need
// //
...@@ -443,8 +454,8 @@ struct tile_window_linear ...@@ -443,8 +454,8 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{ {
using vector_t = typename traits::vector_t; using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys; using SFC_Ys = typename traits::SFC_Ys;
...@@ -453,9 +464,8 @@ struct tile_window_linear ...@@ -453,9 +464,8 @@ struct tile_window_linear
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr); auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) {
static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto IAccess = number<i_access_>{};
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
...@@ -494,17 +504,22 @@ struct tile_window_linear ...@@ -494,17 +504,22 @@ struct tile_window_linear
dst_tensor.get_thread_buffer().template get_as<vector_t>()( dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value); number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif #endif
}); };
WINDOW_DISPATCH_ISSUE();
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false> template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access> = {}, // negative means loop over all num_access
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const bool_constant<pre_nop> = {}) const
{ {
using vector_t = typename traits::vector_t; using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys; using SFC_Ys = typename traits::SFC_Ys;
static constexpr index_t YElementSize = static constexpr index_t YElementSize =
...@@ -516,11 +531,10 @@ struct tile_window_linear ...@@ -516,11 +531,10 @@ struct tile_window_linear
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer()); auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) {
static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto IAccess = number<i_access_>{};
constexpr auto IAccess = number<i_access>{};
constexpr auto pre_nop_ = [&]() { constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0 && if constexpr(pre_nop && i_access_ == 0 &&
BottomTensorView::buffer_view::get_address_space() == BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global) address_space_enum::global)
return bool_constant<true>{}; return bool_constant<true>{};
...@@ -550,16 +564,18 @@ struct tile_window_linear ...@@ -550,16 +564,18 @@ struct tile_window_linear
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif #endif
}); };
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much " WINDOW_DISPATCH_ISSUE();
"scratch memory" ::);
#endif
} }
// TODO: currently async load only implemented in inline asm // TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false> template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const bool_constant<pre_nop> = {}) const
{ {
...@@ -600,10 +616,10 @@ struct tile_window_linear ...@@ -600,10 +616,10 @@ struct tile_window_linear
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() { constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0) if constexpr(pre_nop && i_access_ == 0)
return bool_constant<true>{}; return bool_constant<true>{};
else else
return bool_constant<false>{}; return bool_constant<false>{};
...@@ -618,15 +634,18 @@ struct tile_window_linear ...@@ -618,15 +634,18 @@ struct tile_window_linear
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_); smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
// move thread coordinate // move thread coordinate
if constexpr(i_access != (NumAccess - 1)) if constexpr(i_access_ != (NumAccess - 1))
{ {
m0_inc_with_memory(size_per_issue); m0_inc_with_memory(size_per_issue);
} }
}); };
WINDOW_DISPATCH_ISSUE();
} }
template <typename LdsTileWindow_, bool oob_conditional_check = true> template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>; using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
...@@ -667,8 +686,8 @@ struct tile_window_linear ...@@ -667,8 +686,8 @@ struct tile_window_linear
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess]; auto bottom_tensor_flag = cached_flags_[IAccess];
...@@ -682,15 +701,18 @@ struct tile_window_linear ...@@ -682,15 +701,18 @@ struct tile_window_linear
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(i_access != (NumAccess - 1)) if constexpr(i_access_ != (NumAccess - 1))
{ {
smem += size_per_issue; // Note we manually increase the per-issue offset smem += size_per_issue; // Note we manually increase the per-issue offset
} }
}); };
WINDOW_DISPATCH_ISSUE();
} }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
...@@ -700,8 +722,8 @@ struct tile_window_linear ...@@ -700,8 +722,8 @@ struct tile_window_linear
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess); constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
...@@ -732,13 +754,15 @@ struct tile_window_linear ...@@ -732,13 +754,15 @@ struct tile_window_linear
bottom_tensor_flag, bottom_tensor_flag,
vec_value, vec_value,
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{});
}); };
WINDOW_DISPATCH_ISSUE();
} }
CK_TILE_DEVICE void template <index_t i_access = -1>
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const
{ {
using vector_t = typename traits::vector_t; using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys; using SFC_Ys = typename traits::SFC_Ys;
...@@ -746,8 +770,8 @@ struct tile_window_linear ...@@ -746,8 +770,8 @@ struct tile_window_linear
static constexpr bool oob_conditional_check = true; static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess); constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
...@@ -773,11 +797,14 @@ struct tile_window_linear ...@@ -773,11 +797,14 @@ struct tile_window_linear
get_bottom_tensor_view() get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>( .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value); bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
}); };
WINDOW_DISPATCH_ISSUE();
} }
template <bool oob_conditional_check = true> template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
...@@ -787,8 +814,8 @@ struct tile_window_linear ...@@ -787,8 +814,8 @@ struct tile_window_linear
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) { auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access>{}; constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{}; constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess); constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
...@@ -820,7 +847,9 @@ struct tile_window_linear ...@@ -820,7 +847,9 @@ struct tile_window_linear
bottom_tensor_flag, bottom_tensor_flag,
vec_value, vec_value,
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{});
}); };
WINDOW_DISPATCH_ISSUE();
} }
// move thread's botom tensor coordiante // move thread's botom tensor coordiante
...@@ -920,6 +949,8 @@ struct tile_window_linear ...@@ -920,6 +949,8 @@ struct tile_window_linear
array<bool, traits::NumAccess> cached_flags_; array<bool, traits::NumAccess> cached_flags_;
}; };
#undef WINDOW_DISPATCH_ISSUE
namespace impl { namespace impl {
template <address_space_enum, index_t len_> template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl struct default_linear_bottom_dims_impl
......
...@@ -17,10 +17,12 @@ namespace ck_tile { ...@@ -17,10 +17,12 @@ namespace ck_tile {
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp, update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>; using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>; using TileDstr = remove_cvref_t<TileDistribution_>;
...@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& ...@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>&
tile_window_tmp.get_window_origin(), tile_window_tmp.get_window_origin(),
tile_dstr); tile_dstr);
tile_window.update(dstr_tensor); tile_window.update(dstr_tensor, number<i_access>{});
} }
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
typename DataType_> typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_, update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor) const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{ {
tile_window.update(dstr_tensor); tile_window.update(dstr_tensor, number<i_access>{});
} }
} // namespace ck_tile } // namespace ck_tile
...@@ -33,8 +33,8 @@ ...@@ -33,8 +33,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp" // #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp" // #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
......
...@@ -45,33 +45,33 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -45,33 +45,33 @@ struct BlockFmhaPipelineQRAsyncEx
static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr index_t Block_M0 = BlockFmhaShape::Block_M0; static constexpr index_t Block_M0 = BlockFmhaShape::Block_M0;
static constexpr index_t Block_N0 = BlockFmhaShape::Block_N0; static constexpr index_t Block_N0 = BlockFmhaShape::Block_N0;
static constexpr index_t Block_K0 = BlockFmhaShape::Block_K0; static constexpr index_t Block_K0 = BlockFmhaShape::Block_K0;
static constexpr index_t BlockWarps_M0 = BlockFmhaShape::BlockWarps_M0; static constexpr index_t BlockWarps_M0 = BlockFmhaShape::BlockWarps_M0;
static constexpr index_t BlockWarps_N0 = BlockFmhaShape::BlockWarps_N0; static constexpr index_t BlockWarps_N0 = BlockFmhaShape::BlockWarps_N0;
static constexpr index_t BlockWarps_K0 = BlockFmhaShape::BlockWarps_K0; static constexpr index_t BlockWarps_K0 = BlockFmhaShape::BlockWarps_K0;
static constexpr index_t Warps_M0 = BlockFmhaShape::Warps_M0; static constexpr index_t Warps_M0 = BlockFmhaShape::Warps_M0;
static constexpr index_t Warps_N0 = BlockFmhaShape::Warps_N0; static constexpr index_t Warps_N0 = BlockFmhaShape::Warps_N0;
static constexpr index_t Warps_K0 = BlockFmhaShape::Warps_K0; static constexpr index_t Warps_K0 = BlockFmhaShape::Warps_K0;
static constexpr index_t Repeat_M0 = BlockFmhaShape::Repeat_M0; static constexpr index_t Repeat_M0 = BlockFmhaShape::Repeat_M0;
static constexpr index_t Repeat_N0 = BlockFmhaShape::Repeat_N0; static constexpr index_t Repeat_N0 = BlockFmhaShape::Repeat_N0;
static constexpr index_t Repeat_K0 = BlockFmhaShape::Repeat_K0; static constexpr index_t Repeat_K0 = BlockFmhaShape::Repeat_K0;
static constexpr index_t Block_M1 = BlockFmhaShape::Block_M1; static constexpr index_t Block_M1 = BlockFmhaShape::Block_M1;
static constexpr index_t Block_N1 = BlockFmhaShape::Block_N1; static constexpr index_t Block_N1 = BlockFmhaShape::Block_N1;
static constexpr index_t Block_K1 = BlockFmhaShape::Block_K1; static constexpr index_t Block_K1 = BlockFmhaShape::Block_K1;
static constexpr index_t BlockWarps_M1 = BlockFmhaShape::BlockWarps_M1; static constexpr index_t BlockWarps_M1 = BlockFmhaShape::BlockWarps_M1;
static constexpr index_t BlockWarps_N1 = BlockFmhaShape::BlockWarps_N1; static constexpr index_t BlockWarps_N1 = BlockFmhaShape::BlockWarps_N1;
static constexpr index_t BlockWarps_K1 = BlockFmhaShape::BlockWarps_K1; static constexpr index_t BlockWarps_K1 = BlockFmhaShape::BlockWarps_K1;
static constexpr index_t Warps_M1 = BlockFmhaShape::Warps_M1; static constexpr index_t Warps_M1 = BlockFmhaShape::Warps_M1;
static constexpr index_t Warps_N1 = BlockFmhaShape::Warps_N1; static constexpr index_t Warps_N1 = BlockFmhaShape::Warps_N1;
static constexpr index_t Warps_K1 = BlockFmhaShape::Warps_K1; static constexpr index_t Warps_K1 = BlockFmhaShape::Warps_K1;
static constexpr index_t Repeat_M1 = BlockFmhaShape::Repeat_M1; static constexpr index_t Repeat_M1 = BlockFmhaShape::Repeat_M1;
static constexpr index_t Repeat_N1 = BlockFmhaShape::Repeat_N1; static constexpr index_t Repeat_N1 = BlockFmhaShape::Repeat_N1;
static constexpr index_t Repeat_K1 = BlockFmhaShape::Repeat_K1; static constexpr index_t Repeat_K1 = BlockFmhaShape::Repeat_K1;
static constexpr index_t UnrollStages = 2; // pipeline unroll the gemm/softmax/gemm static constexpr index_t UnrollStages = 2; // pipeline unroll the gemm/softmax/gemm
static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
...@@ -205,49 +205,47 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -205,49 +205,47 @@ struct BlockFmhaPipelineQRAsyncEx
"wrong!"); "wrong!");
// K tile in LDS // K tile in LDS
auto k_lds_store = [&](){ auto k_lds_store = [&]() {
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr); auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return generate_tuple( return generate_tuple(
[&](auto i_buf) { [&](auto i_buf) {
return make_tile_window( return make_tile_window(
make_tensor_view<address_space_enum::lds>( make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeSmemStoreDesc_K<Problem>(i_buf)), k_lds_ptr, Policy::template MakeSmemStoreDesc_K<Problem>(i_buf)),
Policy::template MakeSmemStoreDesc_K<Problem>(i_buf).get_lengths(), Policy::template MakeSmemStoreDesc_K<Problem>(i_buf).get_lengths(),
{0, 0, 0}); {0, 0, 0});
}, },
number<Policy::NumPrefetchK>{}); number<Policy::NumPrefetchK>{});
}(); }();
auto k_lds_load = [&](){ auto k_lds_load = [&]() {
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr); auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return make_tile_window( return make_tile_window(make_tensor_view<address_space_enum::lds>(
make_tensor_view<address_space_enum::lds>( k_lds_ptr, Policy::template MakeSmemLoadDesc_K<Problem>()),
k_lds_ptr, Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(),
Policy::template MakeSmemLoadDesc_K<Problem>()), {0, 0});
Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(), {0, 0});
}(); }();
// V tile in LDS // V tile in LDS
auto v_lds_store = [&](){ auto v_lds_store = [&]() {
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr); auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return generate_tuple( return generate_tuple(
[&](auto i_buf) { [&](auto i_buf) {
return make_tile_window( return make_tile_window(
make_tensor_view<address_space_enum::lds>( make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeSmemStoreDesc_V<Problem>(i_buf)), v_lds_ptr, Policy::template MakeSmemStoreDesc_V<Problem>(i_buf)),
Policy::template MakeSmemStoreDesc_V<Problem>(i_buf).get_lengths(), Policy::template MakeSmemStoreDesc_V<Problem>(i_buf).get_lengths(),
{0, 0, 0}); {0, 0, 0});
}, },
number<Policy::NumPrefetchV>{}); number<Policy::NumPrefetchV>{});
}(); }();
auto v_lds_load = [&](){ auto v_lds_load = [&]() {
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr); auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return make_tile_window( return make_tile_window(make_tensor_view<address_space_enum::lds>(
make_tensor_view<address_space_enum::lds>( v_lds_ptr, Policy::template MakeSmemLoadDesc_V<Problem>()),
v_lds_ptr, Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(),
Policy::template MakeSmemLoadDesc_V<Problem>()), {0, 0});
Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(), {0, 0});
}(); }();
// reduction function for softmax // reduction function for softmax
...@@ -258,22 +256,20 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -258,22 +256,20 @@ struct BlockFmhaPipelineQRAsyncEx
constexpr auto warp_gemm_0 = Policy::template GetWarpGemm_0<Problem>(); constexpr auto warp_gemm_0 = Policy::template GetWarpGemm_0<Problem>();
constexpr auto warp_gemm_1 = Policy::template GetWarpGemm_1<Problem>(); constexpr auto warp_gemm_1 = Policy::template GetWarpGemm_1<Problem>();
auto gemm_0 = [&](){ auto gemm_0 = [&]() {
constexpr index_t total_repeats = Repeat_M0 * Repeat_N0 * Repeat_K0; constexpr index_t total_repeats = Repeat_M0 * Repeat_N0 * Repeat_K0;
// n*k*m, more relaxed ds_read // n*k*m, more relaxed ds_read
static_for<0, total_repeats, 1>{}( static_for<0, total_repeats, 1>{}([&](auto i_r) {
[&](auto i_r){ constexpr index_t i_m = i_r % Repeat_M0;
constexpr index_t i_m = i_r % Repeat_M0; constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0;
constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0; constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0);
constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0); });
}
);
}; };
auto q_dram_window = make_tile_window_raw(q_dram_block_window_tmp.get_bottom_tensor_view(), auto q_dram_window = make_tile_window_raw(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeGlobalDesc_Q<Problem>()); Policy::template MakeGlobalDesc_Q<Problem>());
// TODO: we use async Copy for K, which is inline asm // TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well // a side effect is we have to use inline asm for q as well
...@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_0<Problem>()); using SaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_0<Problem>());
auto s_accs = generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{}); auto s_accs =
generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{});
// infer Sacc, S, P, M, L, Oacc type // infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs)); using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
...@@ -296,9 +293,10 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -296,9 +293,10 @@ struct BlockFmhaPipelineQRAsyncEx
using OaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_1<Problem>()); using OaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_1<Problem>());
// init Oacc, M, L // init Oacc, M, L
auto o_accs = generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{}); auto o_accs =
auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{}); generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{});
auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{}); auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
static_for<0, UnrollStages, 1>{}([&](auto i) { static_for<0, UnrollStages, 1>{}([&](auto i) {
clear_tile(o_accs(i)); clear_tile(o_accs(i));
......
...@@ -105,16 +105,16 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -105,16 +105,16 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_0() CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_0()
{ {
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_0())::CWarpDstrEncoding; using AccWarpDescEnc_ = typename decltype(GetWarpGemm_0())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>; using BlockTile_ =
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0, Problem::BlockFmhaShape::BlockWarps_N0>; sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>; using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0,
constexpr auto enc = make_block_gemm_acc_enc< Problem::BlockFmhaShape::BlockWarps_N0>;
AccWarpDescEnc_, using WarpTile_ =
BlockTile_, sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>;
BlockWarps_, constexpr auto enc =
WarpTile_>(); make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc); constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr); auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr);
return t; return t;
} }
...@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx
{ {
if constexpr(Problem::kHasDropout) if constexpr(Problem::kHasDropout)
{ {
constexpr index_t kMPerStep = Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0; constexpr index_t kMPerStep =
constexpr index_t kNPerStep = Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0; Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0;
constexpr index_t kNPerStep =
Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0;
return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t); return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
} }
...@@ -612,16 +614,16 @@ struct BlockFmhaPipelineQRAsyncEx ...@@ -612,16 +614,16 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_1() CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_1()
{ {
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_1())::CWarpDstrEncoding; using AccWarpDescEnc_ = typename decltype(GetWarpGemm_1())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>; using BlockTile_ =
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1, Problem::BlockFmhaShape::BlockWarps_N1>; sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>; using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1,
constexpr auto enc = make_block_gemm_acc_enc< Problem::BlockFmhaShape::BlockWarps_N1>;
AccWarpDescEnc_, using WarpTile_ =
BlockTile_, sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>;
BlockWarps_, constexpr auto enc =
WarpTile_>(); make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc); constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr); auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr);
return t; return t;
} }
}; };
......
...@@ -43,15 +43,15 @@ struct TileFmhaShape ...@@ -43,15 +43,15 @@ struct TileFmhaShape
ck_tile::tensor_layout::gemm::ColumnMajor>; ck_tile::tensor_layout::gemm::ColumnMajor>;
// gemm-0 shapes TODO: naming? // gemm-0 shapes TODO: naming?
static constexpr index_t Block_M0 = kM0; static constexpr index_t Block_M0 = kM0;
static constexpr index_t Block_N0 = kN0; static constexpr index_t Block_N0 = kN0;
static constexpr index_t Block_K0 = kK0; static constexpr index_t Block_K0 = kK0;
static constexpr index_t BlockWarps_M0 = Gemm0BlockWarps::at(number<0>{}); static constexpr index_t BlockWarps_M0 = Gemm0BlockWarps::at(number<0>{});
static constexpr index_t BlockWarps_N0 = Gemm0BlockWarps::at(number<1>{}); static constexpr index_t BlockWarps_N0 = Gemm0BlockWarps::at(number<1>{});
static constexpr index_t BlockWarps_K0 = Gemm0BlockWarps::at(number<2>{}); static constexpr index_t BlockWarps_K0 = Gemm0BlockWarps::at(number<2>{});
static constexpr index_t Warps_M0 = Gemm0WarpTile::at(number<0>{}); static constexpr index_t Warps_M0 = Gemm0WarpTile::at(number<0>{});
static constexpr index_t Warps_N0 = Gemm0WarpTile::at(number<1>{}); static constexpr index_t Warps_N0 = Gemm0WarpTile::at(number<1>{});
static constexpr index_t Warps_K0 = Gemm0WarpTile::at(number<2>{}); static constexpr index_t Warps_K0 = Gemm0WarpTile::at(number<2>{});
static_assert(Block_M0 % (BlockWarps_M0 * Warps_M0) == 0); static_assert(Block_M0 % (BlockWarps_M0 * Warps_M0) == 0);
static_assert(Block_N0 % (BlockWarps_N0 * Warps_N0) == 0); static_assert(Block_N0 % (BlockWarps_N0 * Warps_N0) == 0);
static_assert(Block_K0 % (BlockWarps_K0 * Warps_K0) == 0); static_assert(Block_K0 % (BlockWarps_K0 * Warps_K0) == 0);
...@@ -59,15 +59,15 @@ struct TileFmhaShape ...@@ -59,15 +59,15 @@ struct TileFmhaShape
static constexpr index_t Repeat_N0 = Block_N0 / (BlockWarps_N0 * Warps_N0); static constexpr index_t Repeat_N0 = Block_N0 / (BlockWarps_N0 * Warps_N0);
static constexpr index_t Repeat_K0 = Block_K0 / (BlockWarps_K0 * Warps_K0); static constexpr index_t Repeat_K0 = Block_K0 / (BlockWarps_K0 * Warps_K0);
static constexpr index_t Block_M1 = kM0; static constexpr index_t Block_M1 = kM0;
static constexpr index_t Block_N1 = kN1; static constexpr index_t Block_N1 = kN1;
static constexpr index_t Block_K1 = kK1; static constexpr index_t Block_K1 = kK1;
static constexpr index_t BlockWarps_M1 = Gemm1BlockWarps::at(number<0>{}); static constexpr index_t BlockWarps_M1 = Gemm1BlockWarps::at(number<0>{});
static constexpr index_t BlockWarps_N1 = Gemm1BlockWarps::at(number<1>{}); static constexpr index_t BlockWarps_N1 = Gemm1BlockWarps::at(number<1>{});
static constexpr index_t BlockWarps_K1 = Gemm1BlockWarps::at(number<2>{}); static constexpr index_t BlockWarps_K1 = Gemm1BlockWarps::at(number<2>{});
static constexpr index_t Warps_M1 = Gemm1WarpTile::at(number<0>{}); static constexpr index_t Warps_M1 = Gemm1WarpTile::at(number<0>{});
static constexpr index_t Warps_N1 = Gemm1WarpTile::at(number<1>{}); static constexpr index_t Warps_N1 = Gemm1WarpTile::at(number<1>{});
static constexpr index_t Warps_K1 = Gemm1WarpTile::at(number<2>{}); static constexpr index_t Warps_K1 = Gemm1WarpTile::at(number<2>{});
static_assert(Block_M1 % (BlockWarps_M1 * Warps_M1) == 0); static_assert(Block_M1 % (BlockWarps_M1 * Warps_M1) == 0);
static_assert(Block_N1 % (BlockWarps_N1 * Warps_N1) == 0); static_assert(Block_N1 % (BlockWarps_N1 * Warps_N1) == 0);
static_assert(Block_K1 % (BlockWarps_K1 * Warps_K1) == 0); static_assert(Block_K1 % (BlockWarps_K1 * Warps_K1) == 0);
......
...@@ -7,10 +7,10 @@ ...@@ -7,10 +7,10 @@
namespace ck_tile { namespace ck_tile {
template<typename AccWarpDescEnc, template <typename AccWarpDescEnc,
typename BlockTile, // seq<M, N> typename BlockTile, // seq<M, N>
typename BlockWarps, typename BlockWarps,
typename WarpTile> typename WarpTile>
CK_TILE_DEVICE_HOST constexpr auto make_block_gemm_acc_enc() CK_TILE_DEVICE_HOST constexpr auto make_block_gemm_acc_enc()
{ {
constexpr index_t Block_M = BlockTile::at(number<0>{}); constexpr index_t Block_M = BlockTile::at(number<0>{});
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline ...@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline
{ {
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW #if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
auto x = load_tile_raw(inp_win, bool_constant<true>{}, bool_constant<true>{}); auto x = load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
buffer_load_fence(number<0>{}); buffer_load_fence(number<0>{});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment