Commit 3de7bd67 authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK_TILE] Use the GEMM example prec input arg

parent e1e8e1ad
......@@ -109,7 +109,9 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
}
template <typename DataType>
float gemm_type_(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
float gemm_type_(const gemm_traits& t,
const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
{
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
......@@ -135,13 +137,16 @@ float gemm_type_(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if(t.data_type == "fp16") {
if(t.data_type == "fp16")
{
return gemm_type_<GemmFp16>(t, args, s);
}
else if(t.data_type == "bf16") {
else if(t.data_type == "bf16")
{
return gemm_type_<GemmBf16>(t, args, s);
}
else {
else
{
throw std::runtime_error("Wrong! Data type not supported!\n");
}
}
......@@ -159,7 +164,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeT>
template <typename ALayout, typename BLayout, typename CLayout, typename DataType>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
......@@ -16,10 +16,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
using Types = GemmBasicTypeConfig<DataTypeT>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using CDataType = typename Types::CDataType;
using Types = GemmBasicTypeConfig<DataType>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using CDataType = typename Types::CDataType;
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
......@@ -55,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeT>
template <typename ALayout, typename BLayout, typename CLayout, typename DataType>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
......@@ -66,7 +66,7 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using Types = GemmBasicTypeConfig<DataTypeT>;
using Types = GemmBasicTypeConfig<DataType>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
......@@ -140,18 +140,18 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ALayout, BLayout, CLayout, DataTypeT>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
invoke_gemm<ALayout, BLayout, CLayout, DataType>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
......@@ -221,7 +221,10 @@ int run_gemm_example_with_layouts(int argc,
}
template <typename DataType>
int run_gemm_example_with_datatype(int argc, char* argv[], const std::string& a_layout, const std::string& b_layout)
int run_gemm_example_with_datatype(int argc,
char* argv[],
const std::string& a_layout,
const std::string& b_layout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
......@@ -229,19 +232,23 @@ int run_gemm_example_with_datatype(int argc, char* argv[], const std::string& a_
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<Row, Row, Row, DataType>(argc, argv, Row{}, Row{}, Row{});
return run_gemm_example_with_layouts<Row, Row, Row, DataType>(
argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<Row, Col, Row, DataType>(argc, argv, Row{}, Col{}, Row{});
return run_gemm_example_with_layouts<Row, Col, Row, DataType>(
argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<Col, Col, Row, DataType>(argc, argv, Col{}, Col{}, Row{});
return run_gemm_example_with_layouts<Col, Col, Row, DataType>(
argc, argv, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts<Col, Row, Row, DataType>(argc, argv, Col{}, Row{}, Row{});
return run_gemm_example_with_layouts<Col, Row, Row, DataType>(
argc, argv, Col{}, Row{}, Row{});
}
else
{
......@@ -257,9 +264,10 @@ int run_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string prec = arg_parser.get_str("prec");
std::string prec = arg_parser.get_str("prec");
if(prec == "fp16") {
if(prec == "fp16")
{
return run_gemm_example_with_datatype<GemmFp16>(argc, argv, a_layout, b_layout);
}
else if(prec == "bf16")
......@@ -270,5 +278,4 @@ int run_gemm_example(int argc, char* argv[])
{
throw std::runtime_error("Unsupported data type!");
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment