"tests/vscode:/vscode.git/clone" did not exist on "436e523bf1616bea22a1df78d93d7522311dccc8"
Commit c2945b96 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

add reviewers comments

parent aa30ef56
# add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
# add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
function (add_gemm_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
......@@ -16,7 +13,7 @@ target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_gemm_example TARGET_NAME MAIN_SRC)
......
......@@ -55,12 +55,13 @@ using CDataType = Types::CDataType;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
/** \brief Struct used for specifying desired gemm details*/
struct gemm_traits
{
std::string data_type;
bool is_a_rowmajor;
bool is_b_rowmajor;
bool is_c_rowmajor;
std::string data_type; /** Tensors datatype, can be set to either fp16 or bf16*/
bool is_a_rowmajor; /** Whether A matrix is rowmajor */
bool is_b_rowmajor; /** Whether B matrix is rowmajor */
bool is_c_rowmajor; /** Whether C matrix is rowmajor */
};
template <typename ADataType_,
......@@ -106,9 +107,18 @@ struct gemm_traits_
};
// host API
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
/**
* \brief Invoke gemm function
*
* \param traits Gemm traits which are used for choosing best instance.
* \param args Runtime gemm host arguments.
* \param s Stream configuration.
* \return Time of execution.
*/
float gemm(const gemm_traits& traits,
const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s);
......@@ -9,7 +9,7 @@
#include <string>
#include <tuple>
#include "gemm_basic.hpp"
#include "gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
......@@ -124,29 +124,6 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_basic.hpp"
#include "gemm.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
......@@ -10,45 +10,6 @@ using FP32 = float;
using FP16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
ck_tile::index_t M_Tile_,
ck_tile::index_t N_Tile_,
ck_tile::index_t K_Tile_,
ck_tile::index_t M_Warp_,
ck_tile::index_t N_Warp_,
ck_tile::index_t K_Warp_,
ck_tile::index_t M_Warp_Tile_,
ck_tile::index_t N_Warp_Tile_,
ck_tile::index_t K_Warp_Tile_,
bool kPadM_,
bool kPadN_,
bool kPadK_>
using trait_ = gemm_traits_<ADataType_,
BDataType_,
AccDataType_,
CDataType_,
ALayout_,
BLayout_,
CLayout_,
M_Tile_,
N_Tile_,
K_Tile_,
M_Warp_,
N_Warp_,
K_Warp_,
M_Warp_Tile_,
N_Warp_Tile_,
K_Warp_Tile_,
kPadM_,
kPadN_,
kPadK_>;
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
......@@ -57,9 +18,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
// universal gemm compute bound RR
std::cout << "fp16 comp\n";
return gemm_<trait_<FP16,
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
......@@ -81,9 +40,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
}
else
{
// universal gemm memory bound RR
std::cout << "fp16 mem\n";
return gemm_<trait_<FP16,
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
......@@ -108,9 +65,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
// universal gemm compute bound RC
std::cout << "fp16 comp RC\n";
return gemm_<trait_<FP16,
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
......@@ -132,9 +87,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
}
else
{
// universal gemm memory bound RC
std::cout << "fp16 mem RC\n";
return gemm_<trait_<FP16,
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
......@@ -159,9 +112,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
// universal gemm compute bound CR
std::cout << "fp16 comp CR\n";
return gemm_<trait_<FP16,
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
......@@ -183,9 +134,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
}
else
{
// universal gemm memory bound CR
std::cout << "fp16 mem CR\n";
return gemm_<trait_<FP16,
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
......@@ -210,9 +159,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
// universal gemm compute bound CC
std::cout << "fp16 comp CC\n";
return gemm_<trait_<FP16,
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
......@@ -234,9 +181,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
}
else
{
// universal gemm memory bound CC
std::cout << "fp16 mem CC\n";
return gemm_<trait_<FP16,
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
......@@ -268,9 +213,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
// universal gemm compute bound RR
std::cout << "bf16 comp\n";
return gemm_<trait_<BF16,
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
......@@ -292,9 +235,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
}
else
{
// universal gemm memory bound RR
std::cout << "bf16 mem\n";
return gemm_<trait_<BF16,
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
......@@ -319,9 +260,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
// universal gemm compute bound RC
std::cout << "bf16 comp RC\n";
return gemm_<trait_<BF16,
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
......@@ -343,9 +282,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
}
else
{
// universal gemm memory bound RC
std::cout << "bf16 mem RC\n";
return gemm_<trait_<BF16,
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
......@@ -370,9 +307,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
// universal gemm compute bound CR
std::cout << "bf16 comp CR\n";
return gemm_<trait_<BF16,
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
......@@ -394,9 +329,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
}
else
{
// universal gemm memory bound CR
std::cout << "bf16 mem CR\n";
return gemm_<trait_<BF16,
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
......@@ -421,9 +354,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
// universal gemm compute bound CC
std::cout << "bf16 comp CC\n";
return gemm_<trait_<BF16,
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
......@@ -445,9 +376,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
}
else
{
// universal gemm memory bound CC
std::cout << "bf16 mem CC\n";
return gemm_<trait_<BF16,
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
......
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Col, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Col, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -5,22 +5,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Col, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Col, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -5,22 +5,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -2,50 +2,11 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include <iostream>
#include "gemm_basic.hpp"
#include "gemm.hpp"
using A = ck_tile::GemmHostArgs;
using S = ck_tile::stream_config;
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
ck_tile::index_t M_Tile_,
ck_tile::index_t N_Tile_,
ck_tile::index_t K_Tile_,
ck_tile::index_t M_Warp_,
ck_tile::index_t N_Warp_,
ck_tile::index_t K_Warp_,
ck_tile::index_t M_Warp_Tile_,
ck_tile::index_t N_Warp_Tile_,
ck_tile::index_t K_Warp_Tile_,
bool kPadM_,
bool kPadN_,
bool kPadK_>
using trait_ = gemm_traits_<ADataType_,
BDataType_,
AccDataType_,
CDataType_,
ALayout_,
BLayout_,
CLayout_,
M_Tile_,
N_Tile_,
K_Tile_,
M_Warp_,
N_Warp_,
K_Warp_,
M_Warp_Tile_,
N_Warp_Tile_,
K_Warp_Tile_,
kPadM_,
kPadN_,
kPadK_>;
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
......
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Col, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Col, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -5,22 +5,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
ck_tile::bf16_t,
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Col, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Col, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -5,22 +5,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment