Commit e1e8e1ad authored by Aleksander Dudek
Browse files

[CK_TILE] Use the GEMM prec input arg

parent 3cad16c4
...@@ -11,9 +11,17 @@ ...@@ -11,9 +11,17 @@
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeConfig>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
using Types = GemmBasicTypeConfig<DataTypeConfig>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false; constexpr bool kPadM = false;
constexpr bool kPadN = false; constexpr bool kPadN = false;
...@@ -100,23 +108,24 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ...@@ -100,23 +108,24 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
return ave_time; return ave_time;
} }
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) template <typename DataType>
float gemm_type_(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{ {
return gemm_<Row, Row, Row>(args, s); return gemm_<Row, Row, Row, DataType>(args, s);
} }
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{ {
return gemm_<Row, Col, Row>(args, s); return gemm_<Row, Col, Row, DataType>(args, s);
} }
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{ {
return gemm_<Col, Row, Row>(args, s); return gemm_<Col, Row, Row, DataType>(args, s);
} }
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{ {
return gemm_<Col, Col, Row>(args, s); return gemm_<Col, Col, Row, DataType>(args, s);
} }
else else
{ {
...@@ -124,6 +133,19 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til ...@@ -124,6 +133,19 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til
} }
} }
// Runtime dispatcher on the requested precision.
//
// Compares t.data_type against the supported precision names and forwards
// t, args and s to gemm_type_ instantiated with the matching tag type
// (GemmFp16 for "fp16", GemmBf16 for "bf16"), returning its result.
//
// Throws std::runtime_error when t.data_type is neither "fp16" nor "bf16".
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
    if(t.data_type == "fp16")
    {
        return gemm_type_<GemmFp16>(t, args, s);
    }

    if(t.data_type == "bf16")
    {
        return gemm_type_<GemmBf16>(t, args, s);
    }

    // No tag type exists for any other precision string.
    throw std::runtime_error("Wrong! Data type not supported!\n");
}
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
...@@ -137,7 +159,7 @@ auto create_args(int argc, char* argv[]) ...@@ -137,7 +159,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride") .insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride") .insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
......
...@@ -10,11 +10,19 @@ ...@@ -10,11 +10,19 @@
#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/epilogue.hpp"
// Empty tag type naming the fp16 GEMM configuration. It carries no data and
// is used only as a template argument (e.g. to select a GemmBasicTypeConfig
// specialization).
struct GemmFp16 {};
// Empty tag type naming the bf16 GEMM configuration. It carries no data and
// is used only as a template argument (e.g. to select a GemmBasicTypeConfig
// specialization).
struct GemmBf16 {};
template <typename DataType> template <typename DataType>
struct GemmBasicTypeConfig; struct GemmBasicTypeConfig;
template <> template <>
struct GemmBasicTypeConfig<ck_tile::half_t> struct GemmBasicTypeConfig<GemmFp16>
{ {
using ADataType = ck_tile::half_t; using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t; using BDataType = ck_tile::half_t;
...@@ -23,6 +31,15 @@ struct GemmBasicTypeConfig<ck_tile::half_t> ...@@ -23,6 +31,15 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM. // ToDo: Add more bias config to support different categories of GEMM.
}; };
// Type configuration selected by the GemmBf16 tag: the A, B and C element
// types are all ck_tile::bf16_t, while accumulation uses float.
template <>
struct GemmBasicTypeConfig<GemmBf16>
{
    // Element types of the two input operands and of the output.
    using ADataType = ck_tile::bf16_t;
    using BDataType = ck_tile::bf16_t;
    using CDataType = ck_tile::bf16_t;

    // Type used for accumulation.
    using AccDataType = float;
};
template <typename T> template <typename T>
struct DataTypeTraits; struct DataTypeTraits;
...@@ -44,13 +61,11 @@ struct DataTypeTraits<ck_tile::half_t> ...@@ -44,13 +61,11 @@ struct DataTypeTraits<ck_tile::half_t>
static constexpr const char* name = "fp16"; static constexpr const char* name = "fp16";
}; };
using Types = GemmBasicTypeConfig<ck_tile::half_t>; template <>
struct DataTypeTraits<ck_tile::bf16_t>
// Specific type aliases for easy access {
using ADataType = Types::ADataType; static constexpr const char* name = "bf16";
using BDataType = Types::BDataType; };
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeT>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf,
...@@ -16,6 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -16,6 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
using Types = GemmBasicTypeConfig<DataTypeT>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using CDataType = typename Types::CDataType;
ck_tile::GemmHostArgs args; ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
...@@ -50,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -50,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time; return ave_time;
} }
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeT>
int run_gemm_example_with_layouts(int argc, int run_gemm_example_with_layouts(int argc,
char* argv[], char* argv[],
const ALayout a_layout = ALayout{}, const ALayout a_layout = ALayout{},
...@@ -61,6 +66,12 @@ int run_gemm_example_with_layouts(int argc, ...@@ -61,6 +66,12 @@ int run_gemm_example_with_layouts(int argc,
if(!result) if(!result)
return -1; return -1;
using Types = GemmBasicTypeConfig<DataTypeT>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k"); ck_tile::index_t K = arg_parser.get_int("k");
...@@ -129,7 +140,7 @@ int run_gemm_example_with_layouts(int argc, ...@@ -129,7 +140,7 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero(); c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero(); c_m_n_dev_result.SetZero();
invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf, invoke_gemm<ALayout, BLayout, CLayout, DataTypeT>(a_m_k_dev_buf,
b_k_n_dev_buf, b_k_n_dev_buf,
c_m_n_dev_buf, c_m_n_dev_buf,
M, M,
...@@ -209,33 +220,55 @@ int run_gemm_example_with_layouts(int argc, ...@@ -209,33 +220,55 @@ int run_gemm_example_with_layouts(int argc,
return pass; return pass;
} }
int run_gemm_example(int argc, char* argv[]) template <typename DataType>
int run_gemm_example_with_datatype(int argc, char* argv[], const std::string& a_layout, const std::string& b_layout)
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
if(!result) if(!result)
return -1; return -1;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R") if(a_layout == "R" && b_layout == "R")
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); return run_gemm_example_with_layouts<Row, Row, Row, DataType>(argc, argv, Row{}, Row{}, Row{});
} }
else if(a_layout == "R" && b_layout == "C") else if(a_layout == "R" && b_layout == "C")
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts<Row, Col, Row, DataType>(argc, argv, Row{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "C") else if(a_layout == "C" && b_layout == "C")
{ {
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); return run_gemm_example_with_layouts<Col, Col, Row, DataType>(argc, argv, Col{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "R") else if(a_layout == "C" && b_layout == "R")
{ {
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); return run_gemm_example_with_layouts<Col, Row, Row, DataType>(argc, argv, Col{}, Row{}, Row{});
} }
else else
{ {
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
} }
} }
// Top-level driver: parses the command line with create_args, reads the
// "a_layout", "b_layout" and "prec" options, and forwards to
// run_gemm_example_with_datatype instantiated with the tag type that matches
// "prec" (GemmFp16 for "fp16", GemmBf16 for "bf16").
//
// Returns -1 when create_args reports failure; otherwise returns whatever the
// selected run_gemm_example_with_datatype instantiation returns.
// Throws std::runtime_error when "prec" is neither "fp16" nor "bf16".
int run_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        return -1;
    }

    const std::string a_layout = arg_parser.get_str("a_layout");
    const std::string b_layout = arg_parser.get_str("b_layout");
    const std::string prec     = arg_parser.get_str("prec");

    if(prec == "fp16")
    {
        return run_gemm_example_with_datatype<GemmFp16>(argc, argv, a_layout, b_layout);
    }

    if(prec == "bf16")
    {
        return run_gemm_example_with_datatype<GemmBf16>(argc, argv, a_layout, b_layout);
    }

    // Any other precision string has no corresponding tag type here.
    throw std::runtime_error("Unsupported data type!");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment