Commit 02f8c487 authored by carlushuang

add single issue api

parent bafb600b
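This commit threads a new `i_access` template parameter (default -1) through the tile load/store/update APIs. A negative value keeps the old behavior of issuing every access in one call; a non-negative value issues exactly one access, so callers can interleave individual memory issues with other instructions. A minimal usage sketch, assuming a window `win` built elsewhere with make_tile_window; `DataType`, `TileDstr`, and the `NumAccess` constant are illustrative stand-ins, not part of this commit:

// hedged sketch: drive a tile load one access at a time
constexpr ck_tile::index_t NumAccess = 4; // illustrative; must match win.get_num_access()
auto t = ck_tile::make_static_distributed_tensor<DataType>(TileDstr{});
ck_tile::static_for<0, NumAccess, 1>{}([&](auto ia) {
    load_tile_raw(t, win, ia); // ia is a number<I>: issue access I only
    // ...independent work can be scheduled here between issues...
});
// the old all-at-once form is unchanged: load_tile_raw(t, win);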
......@@ -21,28 +21,32 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename T,
......@@ -50,6 +54,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile_window.load_raw(
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
......@@ -68,6 +75,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile_window.load_raw(
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
// for this API we require the user to pass smem specified with the CK_TILE_LDS_ADDR attribute
......@@ -89,6 +99,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto
async_load_tile(LdsTileWindow_&& lds_tile,
......@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
......@@ -106,15 +119,18 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
......@@ -122,6 +138,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
......@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
......@@ -142,6 +162,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
......@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
template <typename WindowLengths, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return null_tensor{};
}
template <typename T, typename WindowLengths>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
template <typename T,
typename WindowLengths,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/,
const null_tile_window<WindowLengths>&,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
}
// TODO: this function requires that certain sub-fields exist on the target tile window
template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false>
template <typename TileWindow,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
......@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
auto t = make_static_distributed_tensor<DataType>(TileDstr{});
load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
load_tile_raw(t, w, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return t;
}
......
......@@ -18,10 +18,12 @@ namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
......@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.store(dstr_tensor);
tile_window.store(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
......@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.store_raw(dstr_tensor);
tile_window.store_raw(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.store(dstr_tensor);
tile_window.store(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.store_raw(dstr_tensor);
tile_window.store_raw(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void store_tile(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.store(dstr_tensor);
tile_window.store(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void store_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.store_raw(dstr_tensor);
tile_window.store_raw(dstr_tensor, number<i_access>{});
}
} // namespace ck_tile
......@@ -18,6 +18,23 @@
namespace ck_tile {
// TODO: is NumCoord no longer needed?
#define WINDOW_DISPATCH_ISSUE_2() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumCoord, 1>{}([&](auto iCoord) { \
static_for<0, NumAccessPerCoord, 1>{}( \
[&](auto iCoordAccess) { issue(iCoord, iCoordAccess); }); \
}); \
} \
else \
{ \
static_assert(i_access < (NumCoord * NumAccessPerCoord)); \
constexpr auto iCoordAccess = number<i_access % NumAccessPerCoord>{}; \
constexpr auto iCoord = number<i_access / NumAccessPerCoord>{}; \
issue(iCoord, iCoordAccess); \
}
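For example (constants illustrative): with NumCoord = 2 and NumAccessPerCoord = 4, the macro maps the flat index i_access = 5 to iCoord = 5 / 4 = 1 and iCoordAccess = 5 % 4 = 1, while any negative i_access falls back to issuing all 2 * 4 = 8 accesses:

// hedged sketch of the dispatch arithmetic above
constexpr ck_tile::index_t NumCoord = 2, NumAccessPerCoord = 4, i_access = 5;
static_assert(i_access / NumAccessPerCoord == 1, "iCoord");
static_assert(i_access % NumAccessPerCoord == 1, "iCoordAccess");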
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
......@@ -283,8 +300,8 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
......@@ -296,12 +313,11 @@ struct tile_window_with_static_distribution
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (capturable in lambdas) once compiled as C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
......@@ -316,20 +332,17 @@ struct tile_window_with_static_distribution
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
......@@ -347,14 +360,19 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
};
WINDOW_DISPATCH_ISSUE_2();
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
......@@ -377,12 +395,11 @@ struct tile_window_with_static_distribution
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (capturable in lambdas) once compiled as C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
......@@ -393,8 +410,7 @@ struct tile_window_with_static_distribution
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
......@@ -405,8 +421,7 @@ struct tile_window_with_static_distribution
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(
""); // this started in rocm-6.2, but the symptom is the same, so reuse this flag
asm volatile(""); // this started in rocm-6.2, but the symptom is the same, so reuse this flag
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -419,17 +434,18 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is a workaround to prevent the compiler from using too much "
"scratch memory" ::);
#endif
};
WINDOW_DISPATCH_ISSUE_2();
}
// TODO: async load is currently implemented only in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
......@@ -470,12 +486,11 @@ struct tile_window_with_static_distribution
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto issue = [&](auto iCoord, auto iCoordAccess) {
// TODO: use structured bindings (capturable in lambdas) once compiled as C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
......@@ -501,12 +516,14 @@ struct tile_window_with_static_distribution
m0_inc_with_memory(size_per_issue);
}
});
});
};
WINDOW_DISPATCH_ISSUE_2();
}
template <typename LdsTileWindow_, bool oob_conditional_check = true>
template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
......@@ -544,12 +561,11 @@ struct tile_window_with_static_distribution
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto issue = [&](auto iCoord, auto iCoordAccess) {
// TODO: use structured bindings (capturable in lambdas) once compiled as C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
......@@ -569,12 +585,13 @@ struct tile_window_with_static_distribution
smem += size_per_issue; // note: we manually advance smem by the per-issue offset
}
});
});
};
WINDOW_DISPATCH_ISSUE_2();
}
template <bool oob_conditional_check = true>
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
......@@ -586,12 +603,11 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (capturable in lambdas) once compiled as C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
......@@ -604,13 +620,11 @@ struct tile_window_with_static_distribution
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
......@@ -620,10 +634,7 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -636,12 +647,13 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
};
WINDOW_DISPATCH_ISSUE_2();
}
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
template <index_t i_access = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const
{
using Traits = load_store_traits;
......@@ -652,12 +664,11 @@ struct tile_window_with_static_distribution
static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (capturable in lambdas) once compiled as C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
......@@ -668,12 +679,10 @@ struct tile_window_with_static_distribution
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
......@@ -694,12 +703,14 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
};
WINDOW_DISPATCH_ISSUE_2();
}
template <bool oob_conditional_check = true>
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
......@@ -710,12 +721,11 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto issue = [&](auto iCoord, auto iCoordAccess) {
/// TODO: use structured bindings (capturable in lambdas) once compiled as C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
......@@ -727,13 +737,11 @@ struct tile_window_with_static_distribution
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
......@@ -741,10 +749,7 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -757,8 +762,9 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
};
WINDOW_DISPATCH_ISSUE_2();
}
// move thread's bottom tensor coordinate
......@@ -857,6 +863,8 @@ struct tile_window_with_static_distribution
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
#undef WINDOW_DISPATCH_ISSUE_2
// TODO: use strategy
template <typename TensorView_,
typename WindowLengths_,
......
......@@ -18,6 +18,17 @@
namespace ck_tile {
#define WINDOW_DISPATCH_ISSUE() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \
} \
else \
{ \
static_assert(i_access < NumAccess); \
issue(number<i_access>{}); \
}
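The linear window uses a flat access index, so a non-negative i_access selects that access directly. A hedged sketch of driving a store one issue at a time; window and tensor construction are elided, and NumAccess is an illustrative stand-in for win.get_num_access():

// issue each store access separately, e.g. to interleave with compute
constexpr ck_tile::index_t NumAccess = 4; // illustrative
ck_tile::static_for<0, NumAccess, 1>{}([&](auto ia) {
    win.store(dstr_tensor, ia); // store only access ia
});
// win.store(dstr_tensor); // default i_access = -1: issue all accesses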
//
// This version of the tile window pre-caches offsets/flags as needed
//
......@@ -443,8 +454,8 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; }
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
......@@ -453,9 +464,8 @@ struct tile_window_linear
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
......@@ -494,17 +504,22 @@ struct tile_window_linear
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
});
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access> = {}, // negative means loop over all NumAccess accesses
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
static constexpr index_t YElementSize =
......@@ -516,11 +531,10 @@ struct tile_window_linear
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0 &&
if constexpr(pre_nop && i_access_ == 0 &&
BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
return bool_constant<true>{};
......@@ -550,16 +564,18 @@ struct tile_window_linear
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(""); // this started in rocm-6.2, but the symptom is the same, so reuse this flag
#endif
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is a workaround to prevent the compiler from using too much "
"scratch memory" ::);
#endif
};
WINDOW_DISPATCH_ISSUE();
}
// TODO: async load is currently implemented only in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
......@@ -600,10 +616,10 @@ struct tile_window_linear
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access == 0)
if constexpr(pre_nop && i_access_ == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
......@@ -618,15 +634,18 @@ struct tile_window_linear
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
// move thread coordinate
if constexpr(i_access != (NumAccess - 1))
if constexpr(i_access_ != (NumAccess - 1))
{
m0_inc_with_memory(size_per_issue);
}
});
};
WINDOW_DISPATCH_ISSUE();
}
template <typename LdsTileWindow_, bool oob_conditional_check = true>
template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
......@@ -667,8 +686,8 @@ struct tile_window_linear
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
......@@ -682,15 +701,18 @@ struct tile_window_linear
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(i_access != (NumAccess - 1))
if constexpr(i_access_ != (NumAccess - 1))
{
smem += size_per_issue; // note: we manually advance smem by the per-issue offset
}
});
};
WINDOW_DISPATCH_ISSUE();
}
template <bool oob_conditional_check = true>
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
......@@ -700,8 +722,8 @@ struct tile_window_linear
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
......@@ -732,13 +754,15 @@ struct tile_window_linear
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{});
});
};
WINDOW_DISPATCH_ISSUE();
}
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
template <index_t i_access = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
......@@ -746,8 +770,8 @@ struct tile_window_linear
static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
......@@ -773,11 +797,14 @@ struct tile_window_linear
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
});
};
WINDOW_DISPATCH_ISSUE();
}
template <bool oob_conditional_check = true>
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
......@@ -787,8 +814,8 @@ struct tile_window_linear
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
......@@ -820,7 +847,9 @@ struct tile_window_linear
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{});
});
};
WINDOW_DISPATCH_ISSUE();
}
// move thread's bottom tensor coordinate
......@@ -920,6 +949,8 @@ struct tile_window_linear
array<bool, traits::NumAccess> cached_flags_;
};
#undef WINDOW_DISPATCH_ISSUE
namespace impl {
template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl
......
......@@ -17,10 +17,12 @@ namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
......@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>&
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.update(dstr_tensor);
tile_window.update(dstr_tensor, number<i_access>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
typename DataType_,
index_t i_access = -1>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {})
{
tile_window.update(dstr_tensor);
tile_window.update(dstr_tensor, number<i_access>{});
}
} // namespace ck_tile
......@@ -33,8 +33,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
// #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
// #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
......
......@@ -205,7 +205,7 @@ struct BlockFmhaPipelineQRAsyncEx
"wrong!");
// K tile in LDS
auto k_lds_store = [&](){
auto k_lds_store = [&]() {
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return generate_tuple(
[&](auto i_buf) {
......@@ -218,17 +218,16 @@ struct BlockFmhaPipelineQRAsyncEx
number<Policy::NumPrefetchK>{});
}();
auto k_lds_load = [&](){
auto k_lds_load = [&]() {
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr,
Policy::template MakeSmemLoadDesc_K<Problem>()),
Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(), {0, 0});
return make_tile_window(make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeSmemLoadDesc_K<Problem>()),
Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(),
{0, 0});
}();
// V tile in LDS
auto v_lds_store = [&](){
auto v_lds_store = [&]() {
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return generate_tuple(
[&](auto i_buf) {
......@@ -241,13 +240,12 @@ struct BlockFmhaPipelineQRAsyncEx
number<Policy::NumPrefetchV>{});
}();
auto v_lds_load = [&](){
auto v_lds_load = [&]() {
auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
v_lds_ptr,
Policy::template MakeSmemLoadDesc_V<Problem>()),
Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(), {0, 0});
return make_tile_window(make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeSmemLoadDesc_V<Problem>()),
Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(),
{0, 0});
}();
// reduction function for softmax
......@@ -258,16 +256,14 @@ struct BlockFmhaPipelineQRAsyncEx
constexpr auto warp_gemm_0 = Policy::template GetWarpGemm_0<Problem>();
constexpr auto warp_gemm_1 = Policy::template GetWarpGemm_1<Problem>();
auto gemm_0 = [&](){
auto gemm_0 = [&]() {
constexpr index_t total_repeats = Repeat_M0 * Repeat_N0 * Repeat_K0;
// n*k*m, more relaxed ds_read
static_for<0, total_repeats, 1>{}(
[&](auto i_r){
static_for<0, total_repeats, 1>{}([&](auto i_r) {
constexpr index_t i_m = i_r % Repeat_M0;
constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0;
constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0);
}
);
});
};
auto q_dram_window = make_tile_window_raw(q_dram_block_window_tmp.get_bottom_tensor_view(),
......@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_0<Problem>());
auto s_accs = generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{});
auto s_accs =
generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{});
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
......@@ -296,7 +293,8 @@ struct BlockFmhaPipelineQRAsyncEx
using OaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_1<Problem>());
// init Oacc, M, L
auto o_accs = generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{});
auto o_accs =
generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{});
auto ms = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
auto ls = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
......
......@@ -105,14 +105,14 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_0()
{
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_0())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0, Problem::BlockFmhaShape::BlockWarps_N0>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>;
constexpr auto enc = make_block_gemm_acc_enc<
AccWarpDescEnc_,
BlockTile_,
BlockWarps_,
WarpTile_>();
using BlockTile_ =
sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0,
Problem::BlockFmhaShape::BlockWarps_N0>;
using WarpTile_ =
sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>;
constexpr auto enc =
make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr);
return t;
......@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx
{
if constexpr(Problem::kHasDropout)
{
constexpr index_t kMPerStep = Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0;
constexpr index_t kNPerStep = Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0;
constexpr index_t kMPerStep =
Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0;
constexpr index_t kNPerStep =
Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0;
return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
}
......@@ -612,14 +614,14 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_1()
{
using AccWarpDescEnc_ = typename decltype(GetWarpGemm_1())::CWarpDstrEncoding;
using BlockTile_ = sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1, Problem::BlockFmhaShape::BlockWarps_N1>;
using WarpTile_ = sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>;
constexpr auto enc = make_block_gemm_acc_enc<
AccWarpDescEnc_,
BlockTile_,
BlockWarps_,
WarpTile_>();
using BlockTile_ =
sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>;
using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1,
Problem::BlockFmhaShape::BlockWarps_N1>;
using WarpTile_ =
sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>;
constexpr auto enc =
make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
constexpr auto dstr = make_static_tile_distribution(enc);
auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr);
return t;
......
......@@ -7,7 +7,7 @@
namespace ck_tile {
template<typename AccWarpDescEnc,
template <typename AccWarpDescEnc,
typename BlockTile, // seq<M, N>
typename BlockWarps,
typename WarpTile>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier(0);
auto x = load_tile_raw(inp_win, bool_constant<true>{}, bool_constant<true>{});
auto x = load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
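// note: the added number<-1>{} argument keeps the previous issue-all behavior;
// it only fills the new single-issue parameter slot introduced in this commit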
buffer_load_fence(number<0>{});
__builtin_amdgcn_sched_barrier(0);
#else
......