Commit 69b6d2ab authored by Adam Osewski's avatar Adam Osewski
Browse files

Enable reading on contiguous dimension in all layouts.

parent bd5008af
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -141,7 +140,9 @@ struct GemmKernel ...@@ -141,7 +140,9 @@ struct GemmKernel
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
{ {
const auto [i_m, i_n] = TilePartitioner{}(); const auto idxs = TilePartitioner{}();
const auto i_m = idxs.at(number<0>{});
const auto i_n = idxs.at(number<1>{});
// options // options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr); const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr); const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
...@@ -160,9 +161,9 @@ struct GemmKernel ...@@ -160,9 +161,9 @@ struct GemmKernel
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
a_start, a_start,
make_tuple(kargs.M, kargs.K), make_tuple(kargs.K, kargs.M),
make_tuple(1, kargs.stride_A), make_tuple(kargs.stride_A, 1),
number<1>{}, number<GemmPipeline::VectorSizeA>{},
number<1>{}); number<1>{});
} }
}(); }();
...@@ -172,9 +173,9 @@ struct GemmKernel ...@@ -172,9 +173,9 @@ struct GemmKernel
{ {
return make_naive_tensor_view<address_space_enum::global>( return make_naive_tensor_view<address_space_enum::global>(
b_start, b_start,
make_tuple(kargs.N, kargs.K), make_tuple(kargs.K, kargs.N),
make_tuple(1, kargs.stride_B), make_tuple(kargs.stride_B, 1),
number<1>{}, number<GemmPipeline::VectorSizeB>{},
number<1>{}); number<1>{});
} }
else else
...@@ -200,16 +201,27 @@ struct GemmKernel ...@@ -200,16 +201,27 @@ struct GemmKernel
{ {
return pad_tensor_view( return pad_tensor_view(
a_tensor_view, a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kK>{}, number<TilePartitioner::kM>{}),
sequence<GemmPipeline::kPadM, false>{}); sequence<false, GemmPipeline::kPadM>{});
} }
}(); }();
// clang-format on
auto a_block_window = make_tile_window( auto a_block_window = [&]() {
a_pad_view, if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), {
{i_m, 0}); return make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
}
else
{
return make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kK>{}, number<TilePartitioner::kM>{}),
{0, i_m});
}
}();
auto b_pad_view = [&]() { auto b_pad_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
...@@ -223,15 +235,27 @@ struct GemmKernel ...@@ -223,15 +235,27 @@ struct GemmKernel
{ {
return pad_tensor_view( return pad_tensor_view(
b_tensor_view, b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kK>{}, number<TilePartitioner::kN>{}),
sequence<GemmPipeline::kPadN, false>{}); sequence<false, GemmPipeline::kPadN>{});
} }
}(); }();
auto b_block_window = make_tile_window( auto b_block_window = [&]() {
b_pad_view, if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), {
{i_n, 0}); return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
}
else
{
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kK>{}, number<TilePartitioner::kN>{}),
{0, i_n});
}
}();
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -12,18 +13,21 @@ struct GemmPipelineAgBgCrImplBase ...@@ -12,18 +13,21 @@ struct GemmPipelineAgBgCrImplBase
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
template <typename DstBlockTile, typename SrcTileWindow> template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{ {
load_tile(dst_block_tile, dram_tile_window); load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock}); move_tile_window(dram_tile_window, dram_tile_window_step);
} }
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction> template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
...@@ -60,19 +64,21 @@ struct GemmPipelineAgBgCrImplBase ...@@ -60,19 +64,21 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view) const const ALdsTensorView& a_lds_block_view) const
{ {
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
// A DRAM tile window for load // A DRAM tile window for load
auto a_copy_dram_window = auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), make_tuple(YPerTile{}, XPerTile{}),
a_dram_block_window_tmp.get_window_origin(), a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>()); Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = auto a_copy_lds_window = make_tile_window(
make_tile_window(a_lds_block_view, a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
auto a_lds_gemm_window = make_tile_window( auto a_lds_gemm_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0}); a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
...@@ -86,18 +92,22 @@ struct GemmPipelineAgBgCrImplBase ...@@ -86,18 +92,22 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view) const const BLdsTensorView& b_lds_block_view) const
{ {
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
auto b_copy_dram_window = auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), make_tuple(YPerTile{}, XPerTile{}),
b_dram_block_window_tmp.get_window_origin(), b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>()); Policy::template MakeBDramTileDistribution<Problem>());
// TODO: Do we really need those two tile windows???
// They're exactly same...
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window = auto b_copy_lds_window = make_tile_window(
make_tile_window(b_lds_block_view, b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
auto b_lds_gemm_window = make_tile_window( auto b_lds_gemm_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}); b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment