"driver/vscode:/vscode.git/clone" did not exist on "9d99a5807298c3f263d39a08328c3c68c930a900"
Commit e1e8e1ad authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK_TILE] Use the GEMM prec input arg

parent 3cad16c4
......@@ -11,9 +11,17 @@
#include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeConfig>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using Types = GemmBasicTypeConfig<DataTypeConfig>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
......@@ -100,23 +108,24 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
return ave_time;
}
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
template <typename DataType>
float gemm_type_(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Row, Row, Row>(args, s);
return gemm_<Row, Row, Row, DataType>(args, s);
}
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Row, Col, Row>(args, s);
return gemm_<Row, Col, Row, DataType>(args, s);
}
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Col, Row, Row>(args, s);
return gemm_<Col, Row, Row, DataType>(args, s);
}
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Col, Col, Row>(args, s);
return gemm_<Col, Col, Row, DataType>(args, s);
}
else
{
......@@ -124,6 +133,19 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til
}
}
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if(t.data_type == "fp16") {
return gemm_type_<GemmFp16>(t, args, s);
}
else if(t.data_type == "bf16") {
return gemm_type_<GemmBf16>(t, args, s);
}
else {
throw std::runtime_error("Wrong! Data type not supported!\n");
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
......@@ -137,7 +159,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
......
......@@ -10,11 +10,19 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
struct GemmFp16
{
};
struct GemmBf16
{
};
template <typename DataType>
struct GemmBasicTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
struct GemmBasicTypeConfig<GemmFp16>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
......@@ -23,6 +31,15 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmBasicTypeConfig<GemmBf16>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <typename T>
struct DataTypeTraits;
......@@ -44,13 +61,11 @@ struct DataTypeTraits<ck_tile::half_t>
static constexpr const char* name = "fp16";
};
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeT>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
......@@ -16,6 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
using Types = GemmBasicTypeConfig<DataTypeT>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using CDataType = typename Types::CDataType;
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
......@@ -50,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeT>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
......@@ -61,6 +66,12 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using Types = GemmBasicTypeConfig<DataTypeT>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
......@@ -129,7 +140,7 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
invoke_gemm<ALayout, BLayout, CLayout, DataTypeT>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
......@@ -209,33 +220,55 @@ int run_gemm_example_with_layouts(int argc,
return pass;
}
int run_gemm_example(int argc, char* argv[])
template <typename DataType>
int run_gemm_example_with_datatype(int argc, char* argv[], const std::string& a_layout, const std::string& b_layout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
return run_gemm_example_with_layouts<Row, Row, Row, DataType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
return run_gemm_example_with_layouts<Row, Col, Row, DataType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
return run_gemm_example_with_layouts<Col, Col, Row, DataType>(argc, argv, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
return run_gemm_example_with_layouts<Col, Row, Row, DataType>(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string prec = arg_parser.get_str("prec");
if(prec == "fp16") {
return run_gemm_example_with_datatype<GemmFp16>(argc, argv, a_layout, b_layout);
}
else if(prec == "bf16")
{
return run_gemm_example_with_datatype<GemmBf16>(argc, argv, a_layout, b_layout);
}
else
{
throw std::runtime_error("Unsupported data type!");
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment