Commit 4cf45f1b authored by Adam Osewski's avatar Adam Osewski
Browse files

Add comment to load_tile_raw and change variable naming style.

parent 6ea43353
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -31,6 +31,15 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
return tile_window.load(bool_constant<oob_conditional_check>{});
}
/**
* @brief Loads a tile of data using inline assembly.
*
* @note Bare in mind that loading data this way, you have to manually initialize your
* thread buffer and synchronize load afterwards in order to make sure it's done before
* using loaded data from registers
* @see `tile_window_with_static_distribution::init_raw()` and `buffer_view.hpp`
* @see `buffer_load_fence()`
*/
template <typename T,
typename BottomTensorView_,
typename WindowLengths_,
......
......@@ -121,7 +121,7 @@ struct GemmKernel
sequence < false,
GemmPipeline::kPadA ? true : false > {});
auto ABlockWindow = make_tile_window(
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
......@@ -132,7 +132,7 @@ struct GemmKernel
sequence < false,
GemmPipeline::kPadB ? true : false > {});
auto BBlockWindow = make_tile_window(
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
......@@ -141,14 +141,12 @@ struct GemmKernel
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
auto c_block_tile =
GemmPipeline{}.template operator()(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
{
}
// Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
......@@ -179,6 +177,7 @@ struct GemmKernel
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow, c_block_tile);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment