Commit 896f8b4c authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

add gemm_api and instances

parent 73a076ee
# add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
# add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
function (add_gemm_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
list(APPEND INSTANCE_SRCS ${source})
endforeach()
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_gemm_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_gemm_example(tile_example_gemm_universal universal_gemm.cpp ${INSTANCE_SRCS})
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
......@@ -9,13 +9,10 @@
#include <string>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
......@@ -103,6 +100,30 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
return ave_time;
}
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Row, Row, Row>(args, s);
}
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Row, Col, Row>(args, s);
}
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Col, Row, Row>(args, s);
}
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Col, Col, Row>(args, s);
}
else
{
throw std::runtime_error("Wrong! Layouts not supported!\n");
}
}
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
template <typename DataType>
struct GemmBasicTypeConfig;
......@@ -51,6 +52,59 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
struct gemm_traits
{
std::string data_type;
bool is_a_rowmajor;
bool is_b_rowmajor;
bool is_c_rowmajor;
};
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
ck_tile::index_t M_Tile_,
ck_tile::index_t N_Tile_,
ck_tile::index_t K_Tile_,
ck_tile::index_t M_Warp_,
ck_tile::index_t N_Warp_,
ck_tile::index_t K_Warp_,
ck_tile::index_t M_Warp_Tile_,
ck_tile::index_t N_Warp_Tile_,
ck_tile::index_t K_Warp_Tile_,
bool kPadM_,
bool kPadN_,
bool kPadK_>
struct gemm_traits_
{
using ADataType = ck_tile::remove_cvref_t<ADataType_>;
using BDataType = ck_tile::remove_cvref_t<BDataType_>;
using AccDataType = ck_tile::remove_cvref_t<AccDataType_>;
using CDataType = ck_tile::remove_cvref_t<CDataType_>;
using ALayout = ck_tile::remove_cvref_t<ALayout_>;
using BLayout = ck_tile::remove_cvref_t<BLayout_>;
using CLayout = ck_tile::remove_cvref_t<CLayout_>;
static constexpr ck_tile::index_t M_Tile = M_Tile_;
static constexpr ck_tile::index_t N_Tile = N_Tile_;
static constexpr ck_tile::index_t K_Tile = K_Tile_;
static constexpr ck_tile::index_t M_Warp = M_Warp_;
static constexpr ck_tile::index_t N_Warp = N_Warp_;
static constexpr ck_tile::index_t K_Warp = K_Warp_;
static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_;
static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_;
static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
......@@ -75,4 +129,9 @@ auto create_args(int argc, char* argv[])
}
// host API
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
float gemm(const gemm_traits& traits,
const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_basic.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP32 = float;
using FP16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
ck_tile::index_t M_Tile_,
ck_tile::index_t N_Tile_,
ck_tile::index_t K_Tile_,
ck_tile::index_t M_Warp_,
ck_tile::index_t N_Warp_,
ck_tile::index_t K_Warp_,
ck_tile::index_t M_Warp_Tile_,
ck_tile::index_t N_Warp_Tile_,
ck_tile::index_t K_Warp_Tile_,
bool kPadM_,
bool kPadN_,
bool kPadK_>
using trait_ = gemm_traits_<ADataType_,
BDataType_,
AccDataType_,
CDataType_,
ALayout_,
BLayout_,
CLayout_,
M_Tile_,
N_Tile_,
K_Tile_,
M_Warp_,
N_Warp_,
K_Warp_,
M_Warp_Tile_,
N_Warp_Tile_,
K_Warp_Tile_,
kPadM_,
kPadN_,
kPadK_>;
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// universal gemm compute bound RR
std::cout << "fp16 comp\n";
return gemm_<trait_<FP16,
FP16,
FP32,
FP16,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
}
else
{
// universal gemm memory bound RR
std::cout << "fp16 mem\n";
return gemm_<trait_<FP16,
FP16,
FP32,
FP16,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
}
}
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// universal gemm compute bound RC
std::cout << "fp16 comp RC\n";
return gemm_<trait_<FP16,
FP16,
FP32,
FP16,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
}
else
{
// universal gemm memory bound RC
std::cout << "fp16 mem RC\n";
return gemm_<trait_<FP16,
FP16,
FP32,
FP16,
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
}
}
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// universal gemm compute bound CR
std::cout << "fp16 comp CR\n";
return gemm_<trait_<FP16,
FP16,
FP32,
FP16,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
}
else
{
// universal gemm memory bound CR
std::cout << "fp16 mem CR\n";
return gemm_<trait_<FP16,
FP16,
FP32,
FP16,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
}
}
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// universal gemm compute bound CC
std::cout << "fp16 comp CC\n";
return gemm_<trait_<FP16,
FP16,
FP32,
FP16,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
}
else
{
// universal gemm memory bound CC
std::cout << "fp16 mem CC\n";
return gemm_<trait_<FP16,
FP16,
FP32,
FP16,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
}
}
else
{
throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n");
}
}
else if(t.data_type.compare("bf16") == 0)
{
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// universal gemm compute bound RR
std::cout << "bf16 comp\n";
return gemm_<trait_<BF16,
BF16,
FP32,
BF16,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
}
else
{
// universal gemm memory bound RR
std::cout << "bf16 mem\n";
return gemm_<trait_<BF16,
BF16,
FP32,
BF16,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
}
}
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// universal gemm compute bound RC
std::cout << "bf16 comp RC\n";
return gemm_<trait_<BF16,
BF16,
FP32,
BF16,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
}
else
{
// universal gemm memory bound RC
std::cout << "bf16 mem RC\n";
return gemm_<trait_<BF16,
BF16,
FP32,
BF16,
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
}
}
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// universal gemm compute bound CR
std::cout << "bf16 comp CR\n";
return gemm_<trait_<BF16,
BF16,
FP32,
BF16,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
}
else
{
// universal gemm memory bound CR
std::cout << "bf16 mem CR\n";
return gemm_<trait_<BF16,
BF16,
FP32,
BF16,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
}
}
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// universal gemm compute bound CC
std::cout << "bf16 comp CC\n";
return gemm_<trait_<BF16,
BF16,
FP32,
BF16,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
}
else
{
// universal gemm memory bound CC
std::cout << "bf16 mem CC\n";
return gemm_<trait_<BF16,
BF16,
FP32,
BF16,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
}
}
else
{
throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n");
}
}
else
{
throw std::runtime_error("Wrong! DataTypes not supported!\n");
}
return 1.0f;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include <iostream>
#include "gemm_basic.hpp"
using A = ck_tile::GemmHostArgs;
using S = ck_tile::stream_config;
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
ck_tile::index_t M_Tile_,
ck_tile::index_t N_Tile_,
ck_tile::index_t K_Tile_,
ck_tile::index_t M_Warp_,
ck_tile::index_t N_Warp_,
ck_tile::index_t K_Warp_,
ck_tile::index_t M_Warp_Tile_,
ck_tile::index_t N_Warp_Tile_,
ck_tile::index_t K_Warp_Tile_,
bool kPadM_,
bool kPadN_,
bool kPadK_>
using trait_ = gemm_traits_<ADataType_,
BDataType_,
AccDataType_,
CDataType_,
ALayout_,
BLayout_,
CLayout_,
M_Tile_,
N_Tile_,
K_Tile_,
M_Warp_,
N_Warp_,
K_Warp_,
M_Warp_Tile_,
N_Warp_Tile_,
K_Warp_Tile_,
kPadM_,
kPadN_,
kPadK_>;
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>,
ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>,
ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename Traits_::AccDataType,
typename Traits_::CDataType,
Traits_::kPadM,
Traits_::kPadN>>;
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
Traits_::kPadN,
Traits_::kPadK,
typename Traits_::ALayout,
typename Traits_::BLayout,
typename Traits_::CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType,
typename Traits_::AccDataType,
GemmShape,
GemmTraits>>;
constexpr int kBlockPerCu = 1;
const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
ck_tile::UniversalGemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType,
typename Traits_::AccDataType,
GemmShape,
GemmTraits,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
if(has_hot_loop)
{
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
static_assert(BaseGemmPipeline::PrefetchStages > 3);
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
}
else
{
// Tail number always Full - #PrefetchStages
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
return ave_time;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment