Commit 3de7bd67 authored by Aleksander Dudek
Browse files

[CK_TILE] Use the GEMM example prec input arg

parent e1e8e1ad
...@@ -109,7 +109,9 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ...@@ -109,7 +109,9 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
} }
template <typename DataType> template <typename DataType>
float gemm_type_(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm_type_(const gemm_traits& t,
const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
{ {
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{ {
...@@ -135,13 +137,16 @@ float gemm_type_(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ...@@ -135,13 +137,16 @@ float gemm_type_(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
if(t.data_type == "fp16") { if(t.data_type == "fp16")
{
return gemm_type_<GemmFp16>(t, args, s); return gemm_type_<GemmFp16>(t, args, s);
} }
else if(t.data_type == "bf16") { else if(t.data_type == "bf16")
{
return gemm_type_<GemmBf16>(t, args, s); return gemm_type_<GemmBf16>(t, args, s);
} }
else { else
{
throw std::runtime_error("Wrong! Data type not supported!\n"); throw std::runtime_error("Wrong! Data type not supported!\n");
} }
} }
...@@ -159,7 +164,7 @@ auto create_args(int argc, char* argv[]) ...@@ -159,7 +164,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride") .insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride") .insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeT> template <typename ALayout, typename BLayout, typename CLayout, typename DataType>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf,
...@@ -16,10 +16,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -16,10 +16,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
using Types = GemmBasicTypeConfig<DataTypeT>; using Types = GemmBasicTypeConfig<DataType>;
using ADataType = typename Types::ADataType; using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType; using BDataType = typename Types::BDataType;
using CDataType = typename Types::CDataType; using CDataType = typename Types::CDataType;
ck_tile::GemmHostArgs args; ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
...@@ -55,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -55,7 +55,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time; return ave_time;
} }
template <typename ALayout, typename BLayout, typename CLayout, typename DataTypeT> template <typename ALayout, typename BLayout, typename CLayout, typename DataType>
int run_gemm_example_with_layouts(int argc, int run_gemm_example_with_layouts(int argc,
char* argv[], char* argv[],
const ALayout a_layout = ALayout{}, const ALayout a_layout = ALayout{},
...@@ -66,7 +66,7 @@ int run_gemm_example_with_layouts(int argc, ...@@ -66,7 +66,7 @@ int run_gemm_example_with_layouts(int argc,
if(!result) if(!result)
return -1; return -1;
using Types = GemmBasicTypeConfig<DataTypeT>; using Types = GemmBasicTypeConfig<DataType>;
using ADataType = typename Types::ADataType; using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType; using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType; using AccDataType = typename Types::AccDataType;
...@@ -140,18 +140,18 @@ int run_gemm_example_with_layouts(int argc, ...@@ -140,18 +140,18 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero(); c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero(); c_m_n_dev_result.SetZero();
invoke_gemm<ALayout, BLayout, CLayout, DataTypeT>(a_m_k_dev_buf, invoke_gemm<ALayout, BLayout, CLayout, DataType>(a_m_k_dev_buf,
b_k_n_dev_buf, b_k_n_dev_buf,
c_m_n_dev_buf, c_m_n_dev_buf,
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
stride_C, stride_C,
kbatch, kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true; bool pass = true;
...@@ -221,7 +221,10 @@ int run_gemm_example_with_layouts(int argc, ...@@ -221,7 +221,10 @@ int run_gemm_example_with_layouts(int argc,
} }
template <typename DataType> template <typename DataType>
int run_gemm_example_with_datatype(int argc, char* argv[], const std::string& a_layout, const std::string& b_layout) int run_gemm_example_with_datatype(int argc,
char* argv[],
const std::string& a_layout,
const std::string& b_layout)
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
if(!result) if(!result)
...@@ -229,19 +232,23 @@ int run_gemm_example_with_datatype(int argc, char* argv[], const std::string& a_ ...@@ -229,19 +232,23 @@ int run_gemm_example_with_datatype(int argc, char* argv[], const std::string& a_
if(a_layout == "R" && b_layout == "R") if(a_layout == "R" && b_layout == "R")
{ {
return run_gemm_example_with_layouts<Row, Row, Row, DataType>(argc, argv, Row{}, Row{}, Row{}); return run_gemm_example_with_layouts<Row, Row, Row, DataType>(
argc, argv, Row{}, Row{}, Row{});
} }
else if(a_layout == "R" && b_layout == "C") else if(a_layout == "R" && b_layout == "C")
{ {
return run_gemm_example_with_layouts<Row, Col, Row, DataType>(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts<Row, Col, Row, DataType>(
argc, argv, Row{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "C") else if(a_layout == "C" && b_layout == "C")
{ {
return run_gemm_example_with_layouts<Col, Col, Row, DataType>(argc, argv, Col{}, Col{}, Row{}); return run_gemm_example_with_layouts<Col, Col, Row, DataType>(
argc, argv, Col{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "R") else if(a_layout == "C" && b_layout == "R")
{ {
return run_gemm_example_with_layouts<Col, Row, Row, DataType>(argc, argv, Col{}, Row{}, Row{}); return run_gemm_example_with_layouts<Col, Row, Row, DataType>(
argc, argv, Col{}, Row{}, Row{});
} }
else else
{ {
...@@ -257,9 +264,10 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -257,9 +264,10 @@ int run_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout"); std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout"); std::string b_layout = arg_parser.get_str("b_layout");
std::string prec = arg_parser.get_str("prec"); std::string prec = arg_parser.get_str("prec");
if(prec == "fp16") { if(prec == "fp16")
{
return run_gemm_example_with_datatype<GemmFp16>(argc, argv, a_layout, b_layout); return run_gemm_example_with_datatype<GemmFp16>(argc, argv, a_layout, b_layout);
} }
else if(prec == "bf16") else if(prec == "bf16")
...@@ -270,5 +278,4 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -270,5 +278,4 @@ int run_gemm_example(int argc, char* argv[])
{ {
throw std::runtime_error("Unsupported data type!"); throw std::runtime_error("Unsupported data type!");
} }
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment