"git@developer.sourcefind.cn:gaoqiong/yaml-cpp.git" did not exist on "2dd1cf4596c392d7c35f5764877d88a344303846"
Commit f79f727c authored by Aleksander Dudek
Browse files

[CK TILE] Refactor GemmKernel - references fix

parent 799cde32
...@@ -160,7 +160,7 @@ struct GemmKernel ...@@ -160,7 +160,7 @@ struct GemmKernel
CDataType* c_ptr, CDataType* c_ptr,
const GemmKargs& kargs) const const GemmKargs& kargs) const
{ {
auto a_tensor_view = [&]() { auto&& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
...@@ -181,7 +181,7 @@ struct GemmKernel ...@@ -181,7 +181,7 @@ struct GemmKernel
} }
}(); }();
auto b_tensor_view = [&]() { auto&& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
...@@ -202,7 +202,7 @@ struct GemmKernel ...@@ -202,7 +202,7 @@ struct GemmKernel
} }
}(); }();
auto c_tensor_view = [&]() { auto&& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
...@@ -227,10 +227,10 @@ struct GemmKernel ...@@ -227,10 +227,10 @@ struct GemmKernel
} }
template <typename TensorView> template <typename TensorView>
CK_TILE_DEVICE auto MakeGemmPadViews(TensorView& views) const CK_TILE_DEVICE auto MakeGemmPadViews(TensorView&& views) const
{ {
auto a_pad_view = [&]() { auto&& a_pad_view = [&]() {
auto a_tensor_view = views.at(I0); auto&& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -247,8 +247,8 @@ struct GemmKernel ...@@ -247,8 +247,8 @@ struct GemmKernel
} }
}(); }();
auto b_pad_view = [&]() { auto&& b_pad_view = [&]() {
auto b_tensor_view = views.at(I1); auto&& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -265,8 +265,8 @@ struct GemmKernel ...@@ -265,8 +265,8 @@ struct GemmKernel
} }
}(); }();
auto c_pad_view = [&]() { auto&& c_pad_view = [&]() {
auto c_tensor_view = views.at(I2); auto&& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
return pad_tensor_view( return pad_tensor_view(
...@@ -288,22 +288,22 @@ struct GemmKernel ...@@ -288,22 +288,22 @@ struct GemmKernel
template <typename PadView> template <typename PadView>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
MakeGemmTileWindows(PadView& views, const index_t i_m, const index_t i_n) const MakeGemmTileWindows(PadView&& views, const index_t i_m, const index_t i_n) const
{ {
auto a_pad_view = views.at(I0); auto&& a_pad_view = views.at(I0);
auto a_block_window = make_tile_window( auto&& a_block_window = make_tile_window(
a_pad_view, a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0}); {i_m, 0});
auto b_pad_view = views.at(I1); auto&& b_pad_view = views.at(I1);
auto b_block_window = make_tile_window( auto&& b_block_window = make_tile_window(
b_pad_view, b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0}); {i_n, 0});
auto c_pad_view = views.at(I2); auto&& c_pad_view = views.at(I2);
const auto c_block_window = make_tile_window( auto&& c_block_window = make_tile_window(
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
...@@ -319,10 +319,9 @@ struct GemmKernel ...@@ -319,10 +319,9 @@ struct GemmKernel
const index_t block_idx_n) const const index_t block_idx_n) const
{ {
// Convert pointers to tensor views // Convert pointers to tensor views
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs); auto&& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto&& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
const auto& gemm_tile_windows = auto&& gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment