Commit c2945b96 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

add reviewers comments

parent aa30ef56
# add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
# add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
function (add_gemm_example TARGET_NAME MAIN_SRC) function (add_gemm_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}") message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have # not using add_example_executable() to add target, since we don't want this to have
...@@ -16,7 +13,7 @@ target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) ...@@ -16,7 +13,7 @@ target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
set(COMPILE_OPTIONS) set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_gemm_example TARGET_NAME MAIN_SRC) endfunction(add_gemm_example TARGET_NAME MAIN_SRC)
......
...@@ -55,12 +55,13 @@ using CDataType = Types::CDataType; ...@@ -55,12 +55,13 @@ using CDataType = Types::CDataType;
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
/** \brief Struct used for specifying desired gemm details*/
struct gemm_traits struct gemm_traits
{ {
std::string data_type; std::string data_type; /** Tensors datatype, can be set to either fp16 or bf16*/
bool is_a_rowmajor; bool is_a_rowmajor; /** Whether A matrix is rowmajor */
bool is_b_rowmajor; bool is_b_rowmajor; /** Whether B matrix is rowmajor */
bool is_c_rowmajor; bool is_c_rowmajor; /** Whether C matrix is rowmajor */
}; };
template <typename ADataType_, template <typename ADataType_,
...@@ -106,9 +107,18 @@ struct gemm_traits_ ...@@ -106,9 +107,18 @@ struct gemm_traits_
}; };
// host API // host API
template <typename Traits_> template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
/**
* \brief Invoke gemm function
*
* \param traits Gemm traits which are used for choosing best instance.
* \param args Runtime gemm host arguments.
* \param s Stream configuration.
* \return Time of execution.
*/
float gemm(const gemm_traits& traits, float gemm(const gemm_traits& traits,
const ck_tile::GemmHostArgs& args, const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s); const ck_tile::stream_config& s);
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include "gemm_basic.hpp" #include "gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
...@@ -124,29 +124,6 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til ...@@ -124,29 +124,6 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til
} }
} }
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t, // clang-format off
ck_tile::bf16_t, template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Col, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::bf16_t,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t, // clang-format off
ck_tile::bf16_t, template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Col, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::bf16_t,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
...@@ -5,22 +5,6 @@ ...@@ -5,22 +5,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::bf16_t, // clang-format off
ck_tile::bf16_t, template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::bf16_t,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t, // clang-format off
ck_tile::bf16_t, template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::bf16_t,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t, // clang-format off
ck_tile::half_t, template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Col, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::half_t,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t, // clang-format off
ck_tile::half_t, template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Col, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::half_t,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
...@@ -5,22 +5,6 @@ ...@@ -5,22 +5,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::half_t, // clang-format off
ck_tile::half_t, template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::half_t,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t, // clang-format off
ck_tile::half_t, template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::half_t,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(const A&, const S&);
...@@ -2,50 +2,11 @@ ...@@ -2,50 +2,11 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include <iostream> #include <iostream>
#include "gemm_basic.hpp" #include "gemm.hpp"
using A = ck_tile::GemmHostArgs; using A = ck_tile::GemmHostArgs;
using S = ck_tile::stream_config; using S = ck_tile::stream_config;
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
ck_tile::index_t M_Tile_,
ck_tile::index_t N_Tile_,
ck_tile::index_t K_Tile_,
ck_tile::index_t M_Warp_,
ck_tile::index_t N_Warp_,
ck_tile::index_t K_Warp_,
ck_tile::index_t M_Warp_Tile_,
ck_tile::index_t N_Warp_Tile_,
ck_tile::index_t K_Warp_Tile_,
bool kPadM_,
bool kPadN_,
bool kPadK_>
using trait_ = gemm_traits_<ADataType_,
BDataType_,
AccDataType_,
CDataType_,
ALayout_,
BLayout_,
CLayout_,
M_Tile_,
N_Tile_,
K_Tile_,
M_Warp_,
N_Warp_,
K_Warp_,
M_Warp_Tile_,
N_Warp_Tile_,
K_Warp_Tile_,
kPadM_,
kPadN_,
kPadK_>;
template <typename Traits_> template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
......
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t, // clang-format off
ck_tile::bf16_t, template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Col, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::bf16_t,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t, // clang-format off
ck_tile::bf16_t, template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Col, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::bf16_t,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
...@@ -5,22 +5,6 @@ ...@@ -5,22 +5,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::bf16_t, // clang-format off
ck_tile::bf16_t, template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::bf16_t,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::bf16_t, // clang-format off
ck_tile::bf16_t, template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::bf16_t,
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t, // clang-format off
ck_tile::half_t, template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Col, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::half_t,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
template float gemm_<trait_<ck_tile::half_t, // clang-format off
ck_tile::half_t, template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Col, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::half_t,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
...@@ -5,22 +5,6 @@ ...@@ -5,22 +5,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
template float gemm_<trait_<ck_tile::half_t, // clang-format off
ck_tile::half_t, template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
float, // clang-format on
ck_tile::half_t,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment