Commit c54a975f authored by carlushuang's avatar carlushuang
Browse files

do not support single issue in old tile window

parent 6a25d081
...@@ -216,7 +216,8 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w, ...@@ -216,7 +216,8 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
auto t = make_static_distributed_tensor<DataType>(TileDstr{}); auto t = make_static_distributed_tensor<DataType>(TileDstr{});
load_tile_raw(t, w, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); load_tile_raw(
t, w, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return t; return t;
} }
......
...@@ -18,23 +18,8 @@ ...@@ -18,23 +18,8 @@
namespace ck_tile { namespace ck_tile {
// TODO: NumCoord no need anymore? // Note: this tile window do not support single issue
#define WINDOW_DISPATCH_ISSUE_2() \ // you need to use tile_window_linear structure for this purpose
if constexpr(i_access < 0) \
{ \
static_for<0, NumCoord, 1>{}([&](auto iCoord) { \
static_for<0, NumAccessPerCoord, 1>{}( \
[&](auto iCoordAccess) { issue(iCoord, iCoordAccess); }); \
}); \
} \
else \
{ \
static_assert(i_access < (NumCoord * NumAccessPerCoord)); \
constexpr auto iCoordAccess = number<i_access % NumAccessPerCoord>{}; \
constexpr auto iCoord = number<i_access / NumAccessPerCoord>{}; \
issue(iCoord, iCoordAccess); \
}
template <typename BottomTensorView_, template <typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename StaticTileDistribution_, typename StaticTileDistribution_,
...@@ -300,8 +285,9 @@ struct tile_window_with_static_distribution ...@@ -300,8 +285,9 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; } CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
template <index_t i_access = -1, bool oob_conditional_check = true> template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -313,66 +299,69 @@ struct tile_window_with_static_distribution ...@@ -313,66 +299,69 @@ struct tile_window_with_static_distribution
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr); auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto iCoord, auto iCoordAccess) { static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20 /// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...] // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from bottom tensor // read from bottom tensor
const vector_t vec_value = const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>( get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{}); bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1 #if 1
// write into distributed tensor // write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array( constexpr auto idx_ys = generate_array(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
}, : idx_ys_start[jj];
number<NDimY>{}); },
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d =
dst_tensor.get_thread_buffer().template at<d>() = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()[j];
}); dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else #else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); constexpr index_t d =
static_assert(d % Traits::ScalarPerVector == 0); tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()( dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value); number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif #endif
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{ {
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(iAccess); constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate( constexpr auto idx_diff_ps_ys = container_concat(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
} idx_diff_ys);
};
WINDOW_DISPATCH_ISSUE_2(); move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, template <typename DstTile,
index_t i_access = -1, index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access> = {}, number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const bool_constant<pre_nop> = {}) const
{ {
...@@ -395,57 +384,59 @@ struct tile_window_with_static_distribution ...@@ -395,57 +384,59 @@ struct tile_window_with_static_distribution
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer()); auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto iCoord, auto iCoordAccess) { static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20 /// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto pre_nop_ = [&]() { constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) constexpr auto pre_nop_ = [&]() {
return bool_constant<true>{}; if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
else return bool_constant<true>{};
return bool_constant<false>{}; else
}(); return bool_constant<false>{};
}();
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); // data index [y0, y1, ...]
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
static_assert(d % Traits::ScalarPerVector == 0); constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>( static_assert(d % Traits::ScalarPerVector == 0);
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord, get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
0 /**/, dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bool_constant<oob_conditional_check>{}, bottom_tensor_thread_coord,
pre_nop_); 0 /**/,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \ #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag asm volatile(
""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif #endif
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{ {
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate( constexpr auto idx_diff_ps_ys =
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
}
};
WINDOW_DISPATCH_ISSUE_2(); move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
} }
// TODO: currently async load only implemented in inline asm // TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
index_t i_access = -1, index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false> bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access> = {}, number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}, bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const bool_constant<pre_nop> = {}) const
{ {
...@@ -486,44 +477,46 @@ struct tile_window_with_static_distribution ...@@ -486,44 +477,46 @@ struct tile_window_with_static_distribution
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto iCoord, auto iCoordAccess) { static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// TODO: use structure binding (to be captured later) if compiled in C++20 /// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto pre_nop_ = [&]() { constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) constexpr auto pre_nop_ = [&]() {
return bool_constant<true>{}; if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
else return bool_constant<true>{};
return bool_constant<false>{}; else
}(); return bool_constant<false>{};
}();
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>( // read from bottom tensor
smem, bottom_tensor_thread_coord, 0, pre_nop_); get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, 0, pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) // move thread coordinate
{ if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); {
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate( constexpr auto idx_diff_ps_ys =
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
m0_inc_with_memory(size_per_issue); move_window_adaptor_and_bottom_tensor_thread_coordinate(
} window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
};
WINDOW_DISPATCH_ISSUE_2(); m0_inc_with_memory(size_per_issue);
}
});
});
} }
template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true> template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access> = {}, number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>; using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
...@@ -561,37 +554,38 @@ struct tile_window_with_static_distribution ...@@ -561,37 +554,38 @@ struct tile_window_with_static_distribution
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto iCoord, auto iCoordAccess) { static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// TODO: use structure binding (to be captured later) if compiled in C++20 /// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor // read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>( get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{}); smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{ {
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys); container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
smem += size_per_issue; // Note we manually increase the per-issue offset smem += size_per_issue; // Note we manually increase the per-issue offset
} }
}; });
WINDOW_DISPATCH_ISSUE_2(); });
} }
template <index_t i_access = -1, bool oob_conditional_check = true> template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}, number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -603,57 +597,62 @@ struct tile_window_with_static_distribution ...@@ -603,57 +597,62 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto iCoord, auto iCoordAccess) { static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...] // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor // read from distributed tensor
// vector_type_t vec; // vector_type_t vec;
vector_t vec_value; vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array( constexpr auto idx_ys = generate_array(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
}, : idx_ys_start[jj];
number<NDimY>{}); },
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) = vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
}); });
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>(); // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor // write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>( get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{}); bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{ {
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys); container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
} }
}; });
WINDOW_DISPATCH_ISSUE_2(); });
} }
template <index_t i_access = -1> template <index_t i_access_unsupport_ = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const number<i_access_unsupport_> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -664,53 +663,55 @@ struct tile_window_with_static_distribution ...@@ -664,53 +663,55 @@ struct tile_window_with_static_distribution
static constexpr bool oob_conditional_check = true; static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto iCoord, auto iCoordAccess) { static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20 /// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_t vec_value; // read from distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { vector_t vec_value;
constexpr auto idx_ys = generate_array( static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
[&](auto jj) { constexpr auto idx_ys = generate_array(
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; [&](auto jj) {
}, return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
number<NDimY>{}); : idx_ys_start[jj];
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); },
vec_value.template get_as<DataType>()(j) = number<NDimY>{});
dstr_tensor.get_thread_buffer().template at<d>(); constexpr index_t d =
}); tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
// write into bottom tensor dstr_tensor.get_thread_buffer().template at<d>();
get_bottom_tensor_view() });
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, 0, vec_value); // write into bottom tensor
get_bottom_tensor_view()
// move thread coordinate .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) bottom_tensor_thread_coord, 0, vec_value);
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys); container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate( move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
} }
}; });
});
WINDOW_DISPATCH_ISSUE_2();
} }
template <index_t i_access = -1, bool oob_conditional_check = true> template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}, number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -721,50 +722,55 @@ struct tile_window_with_static_distribution ...@@ -721,50 +722,55 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = TileDstr{}; constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...] // loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto iCoord, auto iCoordAccess) { static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20 /// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...] // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor // read from distributed tensor
vector_t vec_value; vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array( constexpr auto idx_ys = generate_array(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
}, : idx_ys_start[jj];
number<NDimY>{}); },
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) = vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
}); });
// write into bottom tensor // write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>( get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{}); bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{ {
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate( constexpr auto idx_diff_ps_ys =
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
}
};
WINDOW_DISPATCH_ISSUE_2(); move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
} }
// move thread's botom tensor coordiante // move thread's botom tensor coordiante
...@@ -863,8 +869,6 @@ struct tile_window_with_static_distribution ...@@ -863,8 +869,6 @@ struct tile_window_with_static_distribution
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_; array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
}; };
#undef WINDOW_DISPATCH_ISSUE_2
// TODO: use strategy // TODO: use strategy
template <typename TensorView_, template <typename TensorView_,
typename WindowLengths_, typename WindowLengths_,
......
...@@ -56,7 +56,8 @@ struct TopkSoftmaxWarpPerRowPipeline ...@@ -56,7 +56,8 @@ struct TopkSoftmaxWarpPerRowPipeline
{ {
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW #if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
auto x = load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{}); auto x =
load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
buffer_load_fence(number<0>{}); buffer_load_fence(number<0>{});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment