Commit 4cf45f1b authored by Adam Osewski's avatar Adam Osewski
Browse files

Add comment to load_tile_raw and change variable naming style.

parent 6ea43353
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -31,6 +31,15 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT ...@@ -31,6 +31,15 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
return tile_window.load(bool_constant<oob_conditional_check>{}); return tile_window.load(bool_constant<oob_conditional_check>{});
} }
/**
* @brief Loads a tile of data using inline assembly.
*
* @note Bear in mind that when loading data this way, you have to manually initialize your
* thread buffer and synchronize the load afterwards to make sure it has completed before
* using the loaded data from registers.
* @see `tile_window_with_static_distribution::init_raw()` and `buffer_view.hpp`
* @see `buffer_load_fence()`
*/
template <typename T, template <typename T,
typename BottomTensorView_, typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
......
...@@ -121,7 +121,7 @@ struct GemmKernel ...@@ -121,7 +121,7 @@ struct GemmKernel
sequence < false, sequence < false,
GemmPipeline::kPadA ? true : false > {}); GemmPipeline::kPadA ? true : false > {});
auto ABlockWindow = make_tile_window( auto a_block_window = make_tile_window(
a_pad_view, a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0}); {i_m, 0});
...@@ -132,7 +132,7 @@ struct GemmKernel ...@@ -132,7 +132,7 @@ struct GemmKernel
sequence < false, sequence < false,
GemmPipeline::kPadB ? true : false > {}); GemmPipeline::kPadB ? true : false > {});
auto BBlockWindow = make_tile_window( auto b_block_window = make_tile_window(
b_pad_view, b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0}); {i_n, 0});
...@@ -141,14 +141,12 @@ struct GemmKernel ...@@ -141,14 +141,12 @@ struct GemmKernel
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
auto c_block_tile =
GemmPipeline{}.template operator()(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
{ // Run GEMM cooperatively by the whole workgroup.
} auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr); CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() { auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
...@@ -179,6 +177,7 @@ struct GemmKernel ...@@ -179,6 +177,7 @@ struct GemmKernel
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
EpiloguePipeline{}(CBlockWindow, c_block_tile); EpiloguePipeline{}(CBlockWindow, c_block_tile);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment