Commit 02f8c487 authored by carlushuang's avatar carlushuang

add single issue api

parent bafb600b
......@@ -21,28 +21,32 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
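A minimal call-site sketch (the window name win is assumed here, not part of this commit): both new parameters are defaulted, so existing call sites compile unchanged, while passing number<I> selects a single access:

    auto t_all = load_tile(win);                                      // previous behavior: issue all accesses
    auto t_one = load_tile(win, number<0>{}, bool_constant<false>{}); // issue only the first access, skip OOB checks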
template <typename T,
......@@ -50,6 +54,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile_window.load_raw(
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
......@@ -68,6 +75,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile_window.load_raw(
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
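Because load_tile_raw fills a caller-provided tile, the single-issue form can populate one destination incrementally; a sketch under assumed names (DataType, TileDstr and NumAccess stand in for the window's traits):

    auto dst = make_static_distributed_tensor<DataType>(TileDstr{});
    static_for<0, NumAccess, 1>{}([&](auto ia) {
        load_tile_raw(dst, tile_window, ia); // issue access ia only
        // independent instructions can be scheduled between issues here
    });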
// for this API we force the user to pass smem declared with the CK_TILE_LDS_ADDR attribute
......@@ -89,6 +99,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto
async_load_tile(LdsTileWindow_&& lds_tile,
......@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
......@@ -106,15 +119,18 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
......@@ -122,6 +138,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
......@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
......@@ -142,6 +162,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
......@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
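The async variants follow the same convention; a hedged sketch (k_lds_window and k_dram_window are assumed names) that splits one LDS prefetch into per-issue chunks so compute can overlap the remaining copies:

    async_load_tile(k_lds_window, k_dram_window, number<0>{}); // kick off the first issue only
    // ... start math that does not depend on this tile ...
    async_load_tile(k_lds_window, k_dram_window, number<1>{}); // then the next issue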
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
template <typename WindowLengths, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return null_tensor{};
}
template <typename T, typename WindowLengths>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
template <typename T,
typename WindowLengths,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/,
const null_tile_window<WindowLengths>&,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
}
// TODO: this function requires certain sub-fields to exist on the target tile window
template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false>
template <typename TileWindow,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
......@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
auto t = make_static_distributed_tensor<DataType>(TileDstr{});
load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
load_tile_raw(t, w, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return t;
}
......
......@@ -18,10 +18,12 @@ namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
......@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.store(dstr_tensor);
tile_window.store(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
......@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.store_raw(dstr_tensor);
tile_window.store_raw(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.store(dstr_tensor);
tile_window.store(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.store_raw(dstr_tensor);
tile_window.store_raw(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void store_tile(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.store(dstr_tensor);
tile_window.store(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void store_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.store_raw(dstr_tensor);
tile_window.store_raw(dstr_tensor, number<i_access>{});
}
} // namespace ck_tile
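The store-side overloads mirror the load side; a sketch with assumed names (out_window, o_tile; index 2 presumes the window has more than two accesses):

    store_tile(out_window, o_tile);              // default i_access = -1: write all accesses
    store_tile(out_window, o_tile, number<2>{}); // write only the third access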
......@@ -18,6 +18,17 @@
namespace ck_tile {
#define WINDOW_DISPATCH_ISSUE() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \
} \
else \
{ \
static_assert(i_access < NumAccess); \
issue(number<i_access>{}); \
}
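For reference, a hand-expanded equivalent of the macro for a window with NumAccess == 2 (illustration only; issue is the local lambda each member function defines right before invoking the macro):

    if constexpr(i_access < 0)
    {
        issue(number<0>{}); // default: every access in order
        issue(number<1>{});
    }
    else
    {
        static_assert(i_access < NumAccess);
        issue(number<i_access>{}); // exactly one access
    }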
//
// This version of the tile window pre-caches offsets/flags as needed
//
......@@ -443,8 +454,8 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; }
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
......@@ -453,9 +464,8 @@ struct tile_window_linear
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
......@@ -494,17 +504,22 @@ struct tile_window_linear
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
});
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access> = {}, // a negative value means loop over all accesses
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
static constexpr index_t YElementSize =
......@@ -516,11 +531,10 @@ struct tile_window_linear
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0 &&
if constexpr(pre_nop && i_access_ == 0 &&
BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
return bool_constant<true>{};
......@@ -550,16 +564,18 @@ struct tile_window_linear
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
};
WINDOW_DISPATCH_ISSUE();
}
// TODO: currently async load is only implemented via inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
......@@ -600,10 +616,10 @@ struct tile_window_linear
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0)
if constexpr(pre_nop && i_access_ == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
......@@ -618,15 +634,18 @@ struct tile_window_linear
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
// move thread coordinate
if constexpr(i_access != (NumAccess - 1))
if constexpr(i_access_ != (NumAccess - 1))
{
m0_inc_with_memory(size_per_issue);
}
});
};
WINDOW_DISPATCH_ISSUE();
}
template <typename LdsTileWindow_, bool oob_conditional_check = true>
template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
......@@ -667,8 +686,8 @@ struct tile_window_linear
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
......@@ -682,15 +701,18 @@ struct tile_window_linear
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(i_access != (NumAccess - 1))
if constexpr(i_access_ != (NumAccess - 1))
{
smem += size_per_issue; // Note we manually increase the per-issue offset
}
});
};
WINDOW_DISPATCH_ISSUE();
}
template <bool oob_conditional_check = true>
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
......@@ -700,8 +722,8 @@ struct tile_window_linear
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
......@@ -732,13 +754,15 @@ struct tile_window_linear
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{});
});
};
WINDOW_DISPATCH_ISSUE();
}
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
template <index_t i_access = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
......@@ -746,8 +770,8 @@ struct tile_window_linear
static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
......@@ -773,11 +797,14 @@ struct tile_window_linear
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
});
};
WINDOW_DISPATCH_ISSUE();
}
template <bool oob_conditional_check = true>
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
......@@ -787,8 +814,8 @@ struct tile_window_linear
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
......@@ -820,7 +847,9 @@ struct tile_window_linear
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{});
});
};
WINDOW_DISPATCH_ISSUE();
}
// move thread's bottom tensor coordinate
......@@ -920,6 +949,8 @@ struct tile_window_linear
array<bool, traits::NumAccess> cached_flags_;
};
#undef WINDOW_DISPATCH_ISSUE
namespace impl {
template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl
......
......@@ -17,10 +17,12 @@ namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
......@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>&
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.update(dstr_tensor);
tile_window.update(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.update(dstr_tensor);
tile_window.update(dstr_tensor, number<i_access>{});
}
} // namespace ck_tile
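update_tile gains the same parameter; a minimal sketch (acc_window, acc_tile and NumAccess are assumed names) issuing the read-modify-write one access at a time:

    static_for<0, NumAccess, 1>{}([&](auto ia) {
        update_tile(acc_window, acc_tile, ia); // update only access ia
    });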
......@@ -33,8 +33,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
// #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
// #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
......
......@@ -45,33 +45,33 @@ struct BlockFmhaPipelineQRAsyncEx
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr index_t Block_M0 = BlockFmhaShape::Block_M0;
static constexpr index_t Block_N0 = BlockFmhaShape::Block_N0;
static constexpr index_t Block_K0 = BlockFmhaShape::Block_K0;
static constexpr index_t Block_M0 = BlockFmhaShape::Block_M0;
static constexpr index_t Block_N0 = BlockFmhaShape::Block_N0;
static constexpr index_t Block_K0 = BlockFmhaShape::Block_K0;
static constexpr index_t BlockWarps_M0 = BlockFmhaShape::BlockWarps_M0;
static constexpr index_t BlockWarps_N0 = BlockFmhaShape::BlockWarps_N0;
static constexpr index_t BlockWarps_K0 = BlockFmhaShape::BlockWarps_K0;
static constexpr index_t Warps_M0 = BlockFmhaShape::Warps_M0;
static constexpr index_t Warps_N0 = BlockFmhaShape::Warps_N0;
static constexpr index_t Warps_K0 = BlockFmhaShape::Warps_K0;
static constexpr index_t Repeat_M0 = BlockFmhaShape::Repeat_M0;
static constexpr index_t Repeat_N0 = BlockFmhaShape::Repeat_N0;
static constexpr index_t Repeat_K0 = BlockFmhaShape::Repeat_K0;
static constexpr index_t Block_M1 = BlockFmhaShape::Block_M1;
static constexpr index_t Block_N1 = BlockFmhaShape::Block_N1;
static constexpr index_t Block_K1 = BlockFmhaShape::Block_K1;
static constexpr index_t Warps_M0 = BlockFmhaShape::Warps_M0;
static constexpr index_t Warps_N0 = BlockFmhaShape::Warps_N0;
static constexpr index_t Warps_K0 = BlockFmhaShape::Warps_K0;
static constexpr index_t Repeat_M0 = BlockFmhaShape::Repeat_M0;
static constexpr index_t Repeat_N0 = BlockFmhaShape::Repeat_N0;
static constexpr index_t Repeat_K0 = BlockFmhaShape::Repeat_K0;
static constexpr index_t Block_M1 = BlockFmhaShape::Block_M1;
static constexpr index_t Block_N1 = BlockFmhaShape::Block_N1;
static constexpr index_t Block_K1 = BlockFmhaShape::Block_K1;
static constexpr index_t BlockWarps_M1 = BlockFmhaShape::BlockWarps_M1;
static constexpr index_t BlockWarps_N1 = BlockFmhaShape::BlockWarps_N1;
static constexpr index_t BlockWarps_K1 = BlockFmhaShape::BlockWarps_K1;
static constexpr index_t Warps_M1 = BlockFmhaShape::Warps_M1;
static constexpr index_t Warps_N1 = BlockFmhaShape::Warps_N1;
static constexpr index_t Warps_K1 = BlockFmhaShape::Warps_K1;
static constexpr index_t Repeat_M1 = BlockFmhaShape::Repeat_M1;
static constexpr index_t Repeat_N1 = BlockFmhaShape::Repeat_N1;
static constexpr index_t Repeat_K1 = BlockFmhaShape::Repeat_K1;
static constexpr index_t Warps_M1 = BlockFmhaShape::Warps_M1;
static constexpr index_t Warps_N1 = BlockFmhaShape::Warps_N1;
static constexpr index_t Warps_K1 = BlockFmhaShape::Warps_K1;
static constexpr index_t Repeat_M1 = BlockFmhaShape::Repeat_M1;
static constexpr index_t Repeat_N1 = BlockFmhaShape::Repeat_N1;
static constexpr index_t Repeat_K1 = BlockFmhaShape::Repeat_K1;
static constexpr index_t UnrollStages = 2; // pipeline unroll the gemm/softmax/gemm
static constexpr index_t UnrollStages = 2; // pipeline unroll the gemm/softmax/gemm
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always supports padding; hdim_q/v support multiples of the vector size (like 8x)
......@@ -205,49 +205,47 @@ struct BlockFmhaPipelineQRAsyncEx
"wrong!");
// K tile in LDS
auto k_lds_store = [&](){
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = [&]() {
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeSmemStoreDesc_K<Problem>(i_buf)),
Policy::template MakeSmemStoreDesc_K<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchK>{});
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeSmemStoreDesc_K<Problem>(i_buf)),
Policy::template MakeSmemStoreDesc_K<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchK>{});
}();
auto k_lds_load = [&](){
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr,
Policy::template MakeSmemLoadDesc_K<Problem>()),
Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(), {0, 0});
auto k_lds_load = [&]() {
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return make_tile_window(make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeSmemLoadDesc_K<Problem>()),
Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(),
{0, 0});
}();
// V tile in LDS
auto v_lds_store = [&](){
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
auto v_lds_store = [&]() {
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeSmemStoreDesc_V<Problem>(i_buf)),
Policy::template MakeSmemStoreDesc_V<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchV>{});
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeSmemStoreDesc_V<Problem>(i_buf)),
Policy::template MakeSmemStoreDesc_V<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchV>{});
}();
auto v_lds_load = [&](){
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
v_lds_ptr,
Policy::template MakeSmemLoadDesc_V<Problem>()),
Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(), {0, 0});
auto v_lds_load = [&]() {
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return make_tile_window(make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeSmemLoadDesc_V<Problem>()),
Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(),
{0, 0});
}();
// reduction function for softmax
......@@ -258,22 +256,20 @@ struct BlockFmhaPipelineQRAsyncEx
constexpr auto warp_gemm_0 = Policy::template GetWarpGemm_0<Problem>();
constexpr auto warp_gemm_1 = Policy::template GetWarpGemm_1<Problem>();
auto gemm_0 = [&](){
auto gemm_0 = [&]() {
constexpr index_t total_repeats = Repeat_M0 * Repeat_N0 * Repeat_K0;
// n*k*m, more relaxed ds_read
static_for<0, total_repeats, 1>{}(
[&](auto i_r){
constexpr index_t i_m = i_r % Repeat_M0;
constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0;
constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0);
}
);
static_for<0, total_repeats, 1>{}([&](auto i_r) {
constexpr index_t i_m = i_r % Repeat_M0;
constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0;
constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0);
});
};
auto q_dram_window = make_tile_window_raw(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeGlobalDesc_Q<Problem>());
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeGlobalDesc_Q<Problem>());
// TODO: we use async copy for K, which is inline asm;
// a side effect is that we have to use inline asm for Q as well
......@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_0<Problem>());
auto s_accs = generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{});
auto s_accs =
generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{});
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
......@@ -296,9 +293,10 @@ struct BlockFmhaPipelineQRAsyncEx
using OaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_1<Problem>());
// init Oacc, M, L
auto o_accs = generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{});
auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
auto o_accs =
generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{});
auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
static_for<0, UnrollStages, 1>{}([&](auto i) {
clear_tile(o_accs(i));
......
......@@ -105,16 +105,16 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_0()
{
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_0())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0, Problem::BlockFmhaShape::BlockWarps_N0>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>;
constexpr auto enc = make_block_gemm_acc_enc<
AccWarpDescEnc_,
BlockTile_,
BlockWarps_,
WarpTile_>();
using BlockTile_ =
sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0,
Problem::BlockFmhaShape::BlockWarps_N0>;
using WarpTile_ =
sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>;
constexpr auto enc =
make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr);
auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr);
return t;
}
......@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx
{
if constexpr(Problem::kHasDropout)
{
constexpr index_t kMPerStep = Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0;
constexpr index_t kNPerStep = Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0;
constexpr index_t kMPerStep =
Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0;
constexpr index_t kNPerStep =
Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0;
return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
}
......@@ -612,16 +614,16 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_1()
{
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_1())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1, Problem::BlockFmhaShape::BlockWarps_N1>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>;
constexpr auto enc = make_block_gemm_acc_enc<
AccWarpDescEnc_,
BlockTile_,
BlockWarps_,
WarpTile_>();
using BlockTile_ =
sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1,
Problem::BlockFmhaShape::BlockWarps_N1>;
using WarpTile_ =
sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>;
constexpr auto enc =
make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr);
auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr);
return t;
}
};
......
......@@ -43,15 +43,15 @@ struct TileFmhaShape
ck_tile::tensor_layout::gemm::ColumnMajor>;
// gemm-0 shapes TODO: naming?
static constexpr index_t Block_M0 = kM0;
static constexpr index_t Block_N0 = kN0;
static constexpr index_t Block_K0 = kK0;
static constexpr index_t Block_M0 = kM0;
static constexpr index_t Block_N0 = kN0;
static constexpr index_t Block_K0 = kK0;
static constexpr index_t BlockWarps_M0 = Gemm0BlockWarps::at(number<0>{});
static constexpr index_t BlockWarps_N0 = Gemm0BlockWarps::at(number<1>{});
static constexpr index_t BlockWarps_K0 = Gemm0BlockWarps::at(number<2>{});
static constexpr index_t Warps_M0 = Gemm0WarpTile::at(number<0>{});
static constexpr index_t Warps_N0 = Gemm0WarpTile::at(number<1>{});
static constexpr index_t Warps_K0 = Gemm0WarpTile::at(number<2>{});
static constexpr index_t Warps_M0 = Gemm0WarpTile::at(number<0>{});
static constexpr index_t Warps_N0 = Gemm0WarpTile::at(number<1>{});
static constexpr index_t Warps_K0 = Gemm0WarpTile::at(number<2>{});
static_assert(Block_M0 % (BlockWarps_M0 * Warps_M0) == 0);
static_assert(Block_N0 % (BlockWarps_N0 * Warps_N0) == 0);
static_assert(Block_K0 % (BlockWarps_K0 * Warps_K0) == 0);
......@@ -59,15 +59,15 @@ struct TileFmhaShape
static constexpr index_t Repeat_N0 = Block_N0 / (BlockWarps_N0 * Warps_N0);
static constexpr index_t Repeat_K0 = Block_K0 / (BlockWarps_K0 * Warps_K0);
static constexpr index_t Block_M1 = kM0;
static constexpr index_t Block_N1 = kN1;
static constexpr index_t Block_K1 = kK1;
static constexpr index_t Block_M1 = kM0;
static constexpr index_t Block_N1 = kN1;
static constexpr index_t Block_K1 = kK1;
static constexpr index_t BlockWarps_M1 = Gemm1BlockWarps::at(number<0>{});
static constexpr index_t BlockWarps_N1 = Gemm1BlockWarps::at(number<1>{});
static constexpr index_t BlockWarps_K1 = Gemm1BlockWarps::at(number<2>{});
static constexpr index_t Warps_M1 = Gemm1WarpTile::at(number<0>{});
static constexpr index_t Warps_N1 = Gemm1WarpTile::at(number<1>{});
static constexpr index_t Warps_K1 = Gemm1WarpTile::at(number<2>{});
static constexpr index_t Warps_M1 = Gemm1WarpTile::at(number<0>{});
static constexpr index_t Warps_N1 = Gemm1WarpTile::at(number<1>{});
static constexpr index_t Warps_K1 = Gemm1WarpTile::at(number<2>{});
static_assert(Block_M1 % (BlockWarps_M1 * Warps_M1) == 0);
static_assert(Block_N1 % (BlockWarps_N1 * Warps_N1) == 0);
static_assert(Block_K1 % (BlockWarps_K1 * Warps_K1) == 0);
......
......@@ -7,10 +7,10 @@
namespace ck_tile {
template<typename AccWarpDescEnc,
typename BlockTile, // seq<M, N>
typename BlockWarps,
typename WarpTile>
template <typename AccWarpDescEnc,
typename BlockTile, // seq<M, N>
typename BlockWarps,
typename WarpTile>
CK_TILE_DEVICE_HOST constexpr auto make_block_gemm_acc_enc()
{
constexpr index_t Block_M = BlockTile::at(number<0>{});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier(0);
auto x = load_tile_raw(inp_win, bool_constant<true>{}, bool_constant<true>{});
auto x = load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
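// number<-1>{} above is the explicit spelling of the default: issue every access;
// a non-negative index would restrict the load to a single access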
buffer_load_fence(number<0>{});
__builtin_amdgcn_sched_barrier(0);
#else
......