Commit f79f727c authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK TILE] Refactor GemmKernel - references fix

parent 799cde32
......@@ -160,7 +160,7 @@ struct GemmKernel
CDataType* c_ptr,
const GemmKargs& kargs) const
{
auto a_tensor_view = [&]() {
auto&& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
......@@ -181,7 +181,7 @@ struct GemmKernel
}
}();
auto b_tensor_view = [&]() {
auto&& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
......@@ -202,7 +202,7 @@ struct GemmKernel
}
}();
auto c_tensor_view = [&]() {
auto&& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
......@@ -227,10 +227,10 @@ struct GemmKernel
}
template <typename TensorView>
CK_TILE_DEVICE auto MakeGemmPadViews(TensorView& views) const
CK_TILE_DEVICE auto MakeGemmPadViews(TensorView&& views) const
{
auto a_pad_view = [&]() {
auto a_tensor_view = views.at(I0);
auto&& a_pad_view = [&]() {
auto&& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
......@@ -247,8 +247,8 @@ struct GemmKernel
}
}();
auto b_pad_view = [&]() {
auto b_tensor_view = views.at(I1);
auto&& b_pad_view = [&]() {
auto&& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(
......@@ -265,8 +265,8 @@ struct GemmKernel
}
}();
auto c_pad_view = [&]() {
auto c_tensor_view = views.at(I2);
auto&& c_pad_view = [&]() {
auto&& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
......@@ -288,22 +288,22 @@ struct GemmKernel
template <typename PadView>
CK_TILE_DEVICE auto
MakeGemmTileWindows(PadView& views, const index_t i_m, const index_t i_n) const
MakeGemmTileWindows(PadView&& views, const index_t i_m, const index_t i_n) const
{
auto a_pad_view = views.at(I0);
auto a_block_window = make_tile_window(
auto&& a_pad_view = views.at(I0);
auto&& a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = views.at(I1);
auto b_block_window = make_tile_window(
auto&& b_pad_view = views.at(I1);
auto&& b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
auto c_pad_view = views.at(I2);
const auto c_block_window = make_tile_window(
auto&& c_pad_view = views.at(I2);
auto&& c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
......@@ -319,10 +319,9 @@ struct GemmKernel
const index_t block_idx_n) const
{
// Convert pointers to tensor views
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
const auto& gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
auto&& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
auto&& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto&& gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment