Commit 4511f877 authored by Chao Liu's avatar Chao Liu
Browse files

refactor profiler

parent 519b6aaf
...@@ -60,8 +60,6 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -60,8 +60,6 @@ bool profile_gemm_reduce_impl(int do_verification,
int StrideB, int StrideB,
int StrideC) int StrideC)
{ {
bool pass = true;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value) if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
...@@ -209,15 +207,13 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -209,15 +207,13 @@ bool profile_gemm_reduce_impl(int do_verification,
} }
} }
if(gemm_ptrs.size() <= 0) std::cout << "found " << gemm_ptrs.size() << " instances" << std::endl;
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name; std::string best_gemm_name;
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
bool pass = true;
// profile device GEMM instances // profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs) for(auto& gemm_ptr : gemm_ptrs)
......
#pragma once
#include <iomanip>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_gemm_splitk_impl(int do_verification,
int init_method,
bool do_log,
int nrepeat,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
// set zero to c_device_buf
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
is_same<CDataType, float>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}
std::cout << "found " << gemm_ptrs.size() << " instances" << std::endl;
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << gemm_name << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
pass = pass &&
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << "does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
...@@ -46,7 +46,7 @@ template <typename ADataType, ...@@ -46,7 +46,7 @@ template <typename ADataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
void profile_grouped_gemm_impl(int do_verification, bool profile_grouped_gemm_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, int nrepeat,
...@@ -57,6 +57,8 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -57,6 +57,8 @@ void profile_grouped_gemm_impl(int do_verification,
std::vector<int> StrideBs, std::vector<int> StrideBs,
std::vector<int> StrideCs) std::vector<int> StrideCs)
{ {
bool pass = true;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value) if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
...@@ -81,6 +83,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -81,6 +83,7 @@ void profile_grouped_gemm_impl(int do_verification,
std::vector<Tensor<ADataType>> a_m_k; std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n; std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results; std::vector<Tensor<CDataType>> c_m_n_device_results;
for(int i = 0; i < Ms.size(); i++) for(int i = 0; i < Ms.size(); i++)
...@@ -90,6 +93,9 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -90,6 +93,9 @@ void profile_grouped_gemm_impl(int do_verification,
b_k_n.push_back( b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
c_m_n_device_results.push_back( c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
...@@ -121,11 +127,6 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -121,11 +127,6 @@ void profile_grouped_gemm_impl(int do_verification,
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{}; const auto c_element_op = CElementOp{};
// if(do_verification)
// {
// }
using DeviceMemPtr = std::unique_ptr<DeviceMem>; using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf; std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
...@@ -165,6 +166,27 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -165,6 +166,27 @@ void profile_grouped_gemm_impl(int do_verification,
p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
} }
// reference calculation
if(do_verification)
{
for(int i = 0; i < gemm_shapes.size(); i++)
{
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_results[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
}
// add device GEMM instances // add device GEMM instances
std::vector< std::vector<
ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr> ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
...@@ -229,6 +251,12 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -229,6 +251,12 @@ void profile_grouped_gemm_impl(int do_verification,
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init C to zero before profiling next kernel
for(int i = 0; i < gemm_shapes.size(); i++)
{
c_device_buf[i]->SetZero();
}
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
...@@ -260,32 +288,10 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -260,32 +288,10 @@ void profile_grouped_gemm_impl(int do_verification,
{ {
for(int i = 0; i < gemm_shapes.size(); i++) for(int i = 0; i < gemm_shapes.size(); i++)
{ {
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
Tensor<CDataType> c_m_n_host_result( pass = pass && ck::utils::check_err(c_m_n_device_results[i].mData,
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); c_m_n_host_results[i].mData);
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_results[i].mData, c_m_n_host_result.mData);
if(do_log) if(do_log)
{ {
...@@ -296,7 +302,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -296,7 +302,7 @@ void profile_grouped_gemm_impl(int do_verification,
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",") std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
<< std::endl; << std::endl;
} }
} }
...@@ -310,6 +316,9 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -310,6 +316,9 @@ void profile_grouped_gemm_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
} // namespace profiler } // namespace profiler
} // namespace profiler } // namespace profiler
......
...@@ -16,28 +16,28 @@ ...@@ -16,28 +16,28 @@
#include "device_batched_gemm_xdl.hpp" #include "device_batched_gemm_xdl.hpp"
#include "profile_batched_gemm_impl.hpp" #include "profile_batched_gemm_impl.hpp"
enum struct GemmMatrixLayout bool profile_batched_gemm(int argc, char* argv[])
{ {
MK_KN_MN, // 0 enum struct GemmMatrixLayout
MK_NK_MN, // 1 {
KM_KN_MN, // 2 MK_KN_MN, // 0
KM_NK_MN, // 3 MK_NK_MN, // 1
MK_KN_NM, // 4 KM_KN_MN, // 2
MK_NK_NM, // 5 KM_NK_MN, // 3
KM_KN_NM, // 6 MK_KN_NM, // 4
KM_NK_NM, // 7 MK_NK_NM, // 5
}; KM_KN_NM, // 6
KM_NK_NM, // 7
};
enum struct GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
BF16_BF16_BF16, // 2 BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3 INT8_INT8_INT8, // 3
}; };
int profile_batched_gemm(int argc, char* argv[])
{
if(!(argc == 15)) if(!(argc == 15))
{ {
printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n"); printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
...@@ -51,7 +51,7 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -51,7 +51,7 @@ int profile_batched_gemm(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
exit(1); return false;
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -73,12 +73,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -73,12 +73,12 @@ int profile_batched_gemm(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_batched_gemm_impl<ck::half_t, return ck::profiler::profile_batched_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -93,12 +93,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -93,12 +93,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_batched_gemm_impl<ck::half_t, return ck::profiler::profile_batched_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -113,12 +113,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -113,12 +113,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_batched_gemm_impl<ck::half_t, return ck::profiler::profile_batched_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -133,12 +133,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -133,12 +133,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_batched_gemm_impl<ck::half_t, return ck::profiler::profile_batched_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -153,12 +153,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -153,12 +153,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t, return ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -173,12 +173,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -173,12 +173,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t, return ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -193,12 +193,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -193,12 +193,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t, return ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -213,12 +213,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -213,12 +213,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t, return ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -233,12 +233,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -233,12 +233,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_batched_gemm_impl<float, return ck::profiler::profile_batched_gemm_impl<float,
float, float,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -253,12 +253,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -253,12 +253,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_batched_gemm_impl<float, return ck::profiler::profile_batched_gemm_impl<float,
float, float,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -273,12 +273,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -273,12 +273,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_batched_gemm_impl<float, return ck::profiler::profile_batched_gemm_impl<float,
float, float,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -293,12 +293,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -293,12 +293,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_batched_gemm_impl<float, return ck::profiler::profile_batched_gemm_impl<float,
float, float,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -313,12 +313,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -313,12 +313,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_batched_gemm_impl<int8_t, return ck::profiler::profile_batched_gemm_impl<int8_t,
int8_t, int8_t,
int8_t, int8_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -333,12 +333,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -333,12 +333,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_batched_gemm_impl<int8_t, return ck::profiler::profile_batched_gemm_impl<int8_t,
int8_t, int8_t,
int8_t, int8_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -353,12 +353,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -353,12 +353,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_batched_gemm_impl<int8_t, return ck::profiler::profile_batched_gemm_impl<int8_t,
int8_t, int8_t,
int8_t, int8_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -373,12 +373,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -373,12 +373,12 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_batched_gemm_impl<int8_t, return ck::profiler::profile_batched_gemm_impl<int8_t,
int8_t, int8_t,
int8_t, int8_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -393,8 +393,8 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -393,8 +393,8 @@ int profile_batched_gemm(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "profile_batched_gemm_reduce_impl.hpp" #include "profile_batched_gemm_reduce_impl.hpp"
int profile_batched_gemm_reduce(int argc, char* argv[]) bool profile_batched_gemm_reduce(int argc, char* argv[])
{ {
enum struct GemmMatrixLayout enum struct GemmMatrixLayout
{ {
...@@ -23,7 +23,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) ...@@ -23,7 +23,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
F16_F16_F16_F32_F32, // 1 F16_F16_F16_F32_F32, // 1
}; };
if(!(argc == 15 || argc == 16)) if(argc != 15)
{ {
printf("arg1: tensor operation (batched_gemm: BatchedGEMM+Reduce)\n"); printf("arg1: tensor operation (batched_gemm: BatchedGEMM+Reduce)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n"); printf("arg2: data type (0: fp32; 1: fp16)\n");
...@@ -36,8 +36,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) ...@@ -36,8 +36,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
printf("arg15: split k into mulitiple batch\n"); return false;
exit(1);
} }
const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
...@@ -59,13 +58,13 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) ...@@ -59,13 +58,13 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t, return ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -81,13 +80,13 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) ...@@ -81,13 +80,13 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::MK_NK_MN) layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t, return ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -103,13 +102,13 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) ...@@ -103,13 +102,13 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_KN_MN) layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t, return ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -125,13 +124,13 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) ...@@ -125,13 +124,13 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_NK_MN) layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t, return ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -146,8 +145,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) ...@@ -146,8 +145,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! this data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -4,36 +4,37 @@ ...@@ -4,36 +4,37 @@
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> #include <half.hpp>
#include "profile_conv_bwd_data_impl.hpp" #include "profile_conv_bwd_data_impl.hpp"
enum struct ConvDataType int profile_conv_bwd_data(int argc, char* argv[])
{ {
F32_F32_F32, // 0 enum struct ConvDataType
F16_F16_F16, // 1 {
BF16_BF16_BF16, // 2 F32_F32_F32, // 0
INT8_INT8_INT8, // 3 F16_F16_F16, // 1
}; BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
enum struct ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum struct ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum struct ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
}; };
int profile_conv_bwd_data(int argc, char* argv[])
{
if(argc != 25) if(argc != 25)
{ {
printf("arg1: tensor operation (conv_bwd: BackwardConvolution)\n"); printf("arg1: tensor operation (conv_bwd: BackwardConvolution)\n");
...@@ -47,7 +48,7 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -47,7 +48,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
printf("arg9: run kernel # of times (>1)\n"); printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
exit(1); return false;
} }
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
...@@ -85,14 +86,14 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -85,14 +86,14 @@ int profile_conv_bwd_data(int argc, char* argv[])
if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
ck::profiler::profile_conv_bwd_data_impl<2, return ck::profiler::profile_conv_bwd_data_impl<2,
float, float,
float, float,
float, float,
float, float,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -111,14 +112,14 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -111,14 +112,14 @@ int profile_conv_bwd_data(int argc, char* argv[])
else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
ck::profiler::profile_conv_bwd_data_impl<2, return ck::profiler::profile_conv_bwd_data_impl<2,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -137,14 +138,14 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -137,14 +138,14 @@ int profile_conv_bwd_data(int argc, char* argv[])
else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC && else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
ck::profiler::profile_conv_bwd_data_impl<2, return ck::profiler::profile_conv_bwd_data_impl<2,
uint16_t, uint16_t,
uint16_t, uint16_t,
uint16_t, uint16_t,
float, float,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -163,14 +164,14 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -163,14 +164,14 @@ int profile_conv_bwd_data(int argc, char* argv[])
else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC && else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
ck::profiler::profile_conv_bwd_data_impl<2, return ck::profiler::profile_conv_bwd_data_impl<2,
int8_t, int8_t,
int8_t, int8_t,
int8_t, int8_t,
int32_t, int32_t,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -188,8 +189,8 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -188,8 +189,8 @@ int profile_conv_bwd_data(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! this Conv data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -6,34 +6,35 @@ ...@@ -6,34 +6,35 @@
#include <half.hpp> #include <half.hpp>
#include "profile_conv_bwd_weight_impl.hpp" #include "profile_conv_bwd_weight_impl.hpp"
enum struct ConvDataType // return true if test pass
bool profile_conv_bwd_weight(int argc, char* argv[])
{ {
F32_F32_F32, // 0 enum struct ConvDataType
F16_F16_F16, // 1 {
BF16_BF16_BF16, // 2 F32_F32_F32, // 0
INT8_INT8_INT8, // 3 F16_F16_F16, // 1
}; BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
enum struct ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum struct ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum struct ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
}; };
int profile_conv_bwd_weight(int argc, char* argv[])
{
if(argc != 26) if(argc != 26)
{ {
printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n"); printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n");
...@@ -48,7 +49,7 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -48,7 +49,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
printf("arg25: split k (>=1)\n"); printf("arg25: split k (>=1)\n");
exit(1); return false;
} }
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
...@@ -88,13 +89,13 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -88,13 +89,13 @@ int profile_conv_bwd_weight(int argc, char* argv[])
if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
ck::profiler::profile_conv_bwd_weight_impl<2, return ck::profiler::profile_conv_bwd_weight_impl<2,
float, float,
float, float,
float, float,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -114,13 +115,13 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -114,13 +115,13 @@ int profile_conv_bwd_weight(int argc, char* argv[])
else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
ck::profiler::profile_conv_bwd_weight_impl<2, return ck::profiler::profile_conv_bwd_weight_impl<2,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -139,8 +140,8 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -139,8 +140,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! this Conv data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -6,32 +6,33 @@ ...@@ -6,32 +6,33 @@
#include <half.hpp> #include <half.hpp>
#include "profile_conv_fwd_bias_relu_impl.hpp" #include "profile_conv_fwd_bias_relu_impl.hpp"
enum struct ConvDataType // return true if test pass
bool profile_conv_fwd_bias_relu(int argc, char* argv[])
{ {
F32_F32_F32, // 0 enum struct ConvDataType
F16_F16_F16, // 1 {
}; F32_F32_F32, // 0
F16_F16_F16, // 1
};
enum struct ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum struct ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum struct ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
}; };
int profile_conv_fwd_bias_relu(int argc, char* argv[])
{
if(argc != 25) if(argc != 25)
{ {
printf("arg1: tensor operation (conv_fwd_bias_relu: ForwardConvolution+Bias+ReLu)\n"); printf("arg1: tensor operation (conv_fwd_bias_relu: ForwardConvolution+Bias+ReLu)\n");
...@@ -45,7 +46,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) ...@@ -45,7 +46,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
printf("arg9: run kernel # of times (>1)\n"); printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
exit(1); return false;
} }
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
...@@ -83,13 +84,13 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) ...@@ -83,13 +84,13 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
ck::profiler::profile_conv_fwd_bias_relu_impl<2, return ck::profiler::profile_conv_fwd_bias_relu_impl<2,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -107,8 +108,8 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) ...@@ -107,8 +108,8 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -6,32 +6,32 @@ ...@@ -6,32 +6,32 @@
#include <half.hpp> #include <half.hpp>
#include "profile_conv_fwd_bias_relu_add_impl.hpp" #include "profile_conv_fwd_bias_relu_add_impl.hpp"
enum struct ConvDataType bool profile_conv_fwd_bias_relu_add(int argc, char* argv[])
{ {
F32_F32_F32, // 0 enum struct ConvDataType
F16_F16_F16, // 1 {
}; F32_F32_F32, // 0
F16_F16_F16, // 1
};
enum struct ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum struct ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum struct ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
}; };
int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
{
if(argc != 25) if(argc != 25)
{ {
printf( printf(
...@@ -46,7 +46,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) ...@@ -46,7 +46,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
printf("arg9: run kernel # of times (>1)\n"); printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
exit(1); return false;
} }
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
...@@ -84,13 +84,14 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) ...@@ -84,13 +84,14 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
ck::profiler::profile_conv_fwd_bias_relu_add_impl<2, return ck::profiler::profile_conv_fwd_bias_relu_add_impl<
ck::half_t, 2,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::convolution::NHWC, ck::half_t,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -108,8 +109,8 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) ...@@ -108,8 +109,8 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp"
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
};
int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
{
if(argc != 25)
{
printf("arg1: tensor operation (conv_fwd_bias_relu_atomic_add: "
"ForwardConvolution+Bias+ReLu+AtomicAdd)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n");
printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n");
printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n");
printf("arg6: verification (0: no; 1: yes)\n");
printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
const int nrepeat = std::stoi(argv[9]);
const ck::index_t N = std::stoi(argv[10]);
const ck::index_t K = std::stoi(argv[11]);
const ck::index_t C = std::stoi(argv[12]);
const ck::index_t Y = std::stoi(argv[13]);
const ck::index_t X = std::stoi(argv[14]);
const ck::index_t Hi = std::stoi(argv[15]);
const ck::index_t Wi = std::stoi(argv[16]);
const ck::index_t conv_stride_h = std::stoi(argv[17]);
const ck::index_t conv_stride_w = std::stoi(argv[18]);
const ck::index_t conv_dilation_h = std::stoi(argv[19]);
const ck::index_t conv_dilation_w = std::stoi(argv[20]);
const ck::index_t in_left_pad_h = std::stoi(argv[21]);
const ck::index_t in_left_pad_w = std::stoi(argv[22]);
const ck::index_t in_right_pad_h = std::stoi(argv[23]);
const ck::index_t in_right_pad_w = std::stoi(argv[24]);
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_fwd_bias_relu_atomic_add_impl<
2,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>(
do_verification,
init_method,
do_log,
nrepeat,
N,
K,
C,
std::vector<ck::index_t>{Hi, Wi},
std::vector<ck::index_t>{Y, X},
std::vector<ck::index_t>{Ho, Wo},
std::vector<ck::index_t>{conv_stride_h, conv_stride_w},
std::vector<ck::index_t>{conv_dilation_h, conv_dilation_w},
std::vector<ck::index_t>{in_left_pad_h, in_left_pad_w},
std::vector<ck::index_t>{in_right_pad_h, in_right_pad_w});
}
else
{
throw std::runtime_error("wrong! data_type & layout for this operator is not implemented");
}
return 1;
}
...@@ -80,7 +80,7 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], ...@@ -80,7 +80,7 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[],
} // namespace } // namespace
int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) bool profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
{ {
const int preParams = 10; const int preParams = 10;
int conv_args = 3 + num_dim_spatial * 6; int conv_args = 3 + num_dim_spatial * 6;
...@@ -98,7 +98,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) ...@@ -98,7 +98,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
printf("arg9: run kernel # of times (>1)\n"); printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
return 1; return false;
} }
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
...@@ -121,14 +121,14 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) ...@@ -121,14 +121,14 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::profiler::profile_convnd_bwd_data_impl<1, return ck::profiler::profile_convnd_bwd_data_impl<1,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
ck::tensor_layout::convolution::NWC, ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC, ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK>( ck::tensor_layout::convolution::NWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -146,14 +146,14 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) ...@@ -146,14 +146,14 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
break; break;
case 2: case 2:
ck::profiler::profile_convnd_bwd_data_impl<2, return ck::profiler::profile_convnd_bwd_data_impl<2,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -171,58 +171,58 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) ...@@ -171,58 +171,58 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
break; break;
case 3: case 3:
ck::profiler::profile_convnd_bwd_data_impl<3, return ck::profiler::profile_convnd_bwd_data_impl<
InDataType, 3,
WeiDataType, InDataType,
OutDataType, WeiDataType,
AccDataType, OutDataType,
ck::tensor_layout::convolution::NDHWC, AccDataType,
ck::tensor_layout::convolution::KZYXC, ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::NDHWK>( ck::tensor_layout::convolution::KZYXC,
do_verification, ck::tensor_layout::convolution::NDHWK>(do_verification,
init_method, init_method,
do_log, do_log,
nrepeat, nrepeat,
params.N, params.N,
params.K, params.K,
params.C, params.C,
params.input_spatial_lengths, params.input_spatial_lengths,
params.filter_spatial_lengths, params.filter_spatial_lengths,
params.GetOutputSpatialLengths(), params.GetOutputSpatialLengths(),
params.conv_filter_strides, params.conv_filter_strides,
params.conv_filter_dilations, params.conv_filter_dilations,
params.input_left_pads, params.input_left_pads,
params.input_right_pads); params.input_right_pads);
break; break;
default: break; default: return false;
} }
}; };
if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
Run(float{}, float{}, float{}, float{}); return Run(float{}, float{}, float{}, float{});
} }
else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
Run(ck::half_t{}, ck::half_t{}, ck::half_t{}, float{}); return Run(ck::half_t{}, ck::half_t{}, ck::half_t{}, float{});
} }
else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC && else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, float{}); return Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, float{});
} }
else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC && else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{ {
Run(int8_t{}, int8_t{}, int8_t{}, int32_t{}); return Run(int8_t{}, int8_t{}, int8_t{}, int32_t{});
} }
else else
{ {
std::cout << "wrong! this Conv data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
return 0; return true;
}
} }
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "conv_fwd_util.hpp" #include "conv_fwd_util.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "fill.hpp" #include "fill.hpp"
#include "profile_convnd_fwd.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
namespace { namespace {
...@@ -295,7 +294,7 @@ void profile_convnd_instances(ConvDataType data_type, ...@@ -295,7 +294,7 @@ void profile_convnd_instances(ConvDataType data_type,
} // namespace } // namespace
int ck::profiler::profile_convnd_fwd(int argc, char* argv[]) bool profile_convnd_fwd(int argc, char* argv[])
{ {
using namespace ck::utils::conv; using namespace ck::utils::conv;
...@@ -347,5 +346,6 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[]) ...@@ -347,5 +346,6 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
std::to_string(num_dim_spatial)); std::to_string(num_dim_spatial));
} }
return 1; // FIXME: return true if test pass, return false if test fail
return true;
} }
...@@ -4,31 +4,29 @@ ...@@ -4,31 +4,29 @@
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_impl.hpp" #include "profile_gemm_impl.hpp"
enum struct GemmMatrixLayout // return true if test pass
bool profile_gemm(int argc, char* argv[])
{ {
MK_KN_MN, // 0 enum struct GemmMatrixLayout
MK_NK_MN, // 1 {
KM_KN_MN, // 2 MK_KN_MN, // 0
KM_NK_MN, // 3 MK_NK_MN, // 1
MK_KN_NM, // 4 KM_KN_MN, // 2
MK_NK_NM, // 5 KM_NK_MN, // 3
KM_KN_NM, // 6 };
KM_NK_NM, // 7
};
enum struct GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
BF16_BF16_BF16, // 2 BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3 INT8_INT8_INT8, // 3
}; };
int profile_gemm(int argc, char* argv[]) if(argc != 14)
{
if(!(argc == 14 || argc == 15))
{ {
printf("arg1: tensor operation (gemm: GEMM)\n"); printf("arg1: tensor operation (gemm: GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
...@@ -41,8 +39,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -41,8 +39,7 @@ int profile_gemm(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n"); return false;
exit(1);
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -59,18 +56,15 @@ int profile_gemm(int argc, char* argv[]) ...@@ -59,18 +56,15 @@ int profile_gemm(int argc, char* argv[])
const int StrideA = std::stoi(argv[11]); const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]); const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]); const int StrideC = std::stoi(argv[13]);
int KBatch = 1;
if(argc == 15)
KBatch = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_impl<ck::half_t, return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -80,17 +74,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -80,17 +74,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_impl<ck::half_t, return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -100,17 +93,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -100,17 +93,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_impl<ck::half_t, return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -120,17 +112,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -120,17 +112,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_impl<ck::half_t, return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -140,17 +131,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -140,17 +131,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_impl<float, return ck::profiler::profile_gemm_impl<float,
float, float,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -160,17 +150,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -160,17 +150,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_impl<float, return ck::profiler::profile_gemm_impl<float,
float, float,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -180,17 +169,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -180,17 +169,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_impl<float, return ck::profiler::profile_gemm_impl<float,
float, float,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -200,17 +188,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -200,17 +188,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_impl<float, return ck::profiler::profile_gemm_impl<float,
float, float,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -220,17 +207,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -220,17 +207,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_impl<int8_t, return ck::profiler::profile_gemm_impl<int8_t,
int8_t, int8_t,
int8_t, int8_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -240,17 +226,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -240,17 +226,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_impl<int8_t, return ck::profiler::profile_gemm_impl<int8_t,
int8_t, int8_t,
int8_t, int8_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -260,17 +245,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -260,17 +245,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_impl<int8_t, return ck::profiler::profile_gemm_impl<int8_t,
int8_t, int8_t,
int8_t, int8_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -280,17 +264,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -280,17 +264,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_impl<int8_t, return ck::profiler::profile_gemm_impl<int8_t,
int8_t, int8_t,
int8_t, int8_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -300,17 +283,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -300,17 +283,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_impl<ck::bhalf_t, return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -320,17 +302,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -320,17 +302,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? K : StrideA, (StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_impl<ck::bhalf_t, return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -340,17 +321,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -340,17 +321,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_impl<ck::bhalf_t, return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -360,17 +340,16 @@ int profile_gemm(int argc, char* argv[]) ...@@ -360,17 +340,16 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB, (StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_impl<ck::bhalf_t, return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::bhalf_t, ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -380,13 +359,12 @@ int profile_gemm(int argc, char* argv[]) ...@@ -380,13 +359,12 @@ int profile_gemm(int argc, char* argv[])
K, K,
(StrideA < 0) ? M : StrideA, (StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC, (StrideC < 0) ? N : StrideC);
KBatch);
} }
else else
{ {
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -6,27 +6,27 @@ ...@@ -6,27 +6,27 @@
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_bias_2d_impl.hpp" #include "profile_gemm_bias_2d_impl.hpp"
enum struct GemmMatrixLayout bool profile_gemm_bias_2d(int argc, char* argv[])
{ {
MK_KN_MN, // 0 enum struct GemmMatrixLayout
MK_NK_MN, // 1 {
KM_KN_MN, // 2 MK_KN_MN, // 0
KM_NK_MN, // 3 MK_NK_MN, // 1
MK_KN_NM, // 4 KM_KN_MN, // 2
MK_NK_NM, // 5 KM_NK_MN, // 3
KM_KN_NM, // 6 MK_KN_NM, // 4
KM_NK_NM, // 7 MK_NK_NM, // 5
}; KM_KN_NM, // 6
KM_NK_NM, // 7
};
enum struct GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
}; };
int profile_gemm_bias_2d(int argc, char* argv[]) if(argc != 16)
{
if(!(argc == 16 || argc == 17))
{ {
printf("arg1: tensor operation (gemm: GEMM+Bias_2d)\n"); printf("arg1: tensor operation (gemm: GEMM+Bias_2d)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n"); printf("arg2: data type (0: fp32; 1: fp16)\n");
...@@ -41,8 +41,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -41,8 +41,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: alpha\n"); printf("arg14: alpha\n");
printf("arg15: beta\n"); printf("arg15: beta\n");
printf("arg16: split k into mulitiple batch\n"); return false;
exit(1);
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -65,14 +64,14 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -65,14 +64,14 @@ int profile_gemm_bias_2d(int argc, char* argv[])
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_bias_2d_impl<float, return ck::profiler::profile_gemm_bias_2d_impl<float,
float, float,
float, float,
float, float,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -88,14 +87,14 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -88,14 +87,14 @@ int profile_gemm_bias_2d(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_bias_2d_impl<float, return ck::profiler::profile_gemm_bias_2d_impl<float,
float, float,
float, float,
float, float,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -111,14 +110,14 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -111,14 +110,14 @@ int profile_gemm_bias_2d(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_bias_2d_impl<float, return ck::profiler::profile_gemm_bias_2d_impl<float,
float, float,
float, float,
float, float,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -134,14 +133,14 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -134,14 +133,14 @@ int profile_gemm_bias_2d(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_bias_2d_impl<float, return ck::profiler::profile_gemm_bias_2d_impl<float,
float, float,
float, float,
float, float,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -157,14 +156,14 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -157,14 +156,14 @@ int profile_gemm_bias_2d(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_bias_2d_impl<ck::half_t, return ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -180,14 +179,14 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -180,14 +179,14 @@ int profile_gemm_bias_2d(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_bias_2d_impl<ck::half_t, return ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -203,14 +202,14 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -203,14 +202,14 @@ int profile_gemm_bias_2d(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_bias_2d_impl<ck::half_t, return ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -226,14 +225,14 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -226,14 +225,14 @@ int profile_gemm_bias_2d(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_bias_2d_impl<ck::half_t, return ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -249,8 +248,8 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -249,8 +248,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! this data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -6,26 +6,26 @@ ...@@ -6,26 +6,26 @@
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_bias_relu_impl.hpp" #include "profile_gemm_bias_relu_impl.hpp"
enum struct GemmMatrixLayout bool profile_gemm_bias_relu(int argc, char* argv[])
{ {
MK_KN_MN, // 0 enum struct GemmMatrixLayout
MK_NK_MN, // 1 {
KM_KN_MN, // 2 MK_KN_MN, // 0
KM_NK_MN, // 3 MK_NK_MN, // 1
MK_KN_NM, // 4 KM_KN_MN, // 2
MK_NK_NM, // 5 KM_NK_MN, // 3
KM_KN_NM, // 6 MK_KN_NM, // 4
KM_NK_NM, // 7 MK_NK_NM, // 5
}; KM_KN_NM, // 6
KM_NK_NM, // 7
};
enum struct GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
}; };
int profile_gemm_bias_relu(int argc, char* argv[])
{
if(!(argc == 14 || argc == 15)) if(!(argc == 14 || argc == 15))
{ {
printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU)\n"); printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU)\n");
...@@ -40,7 +40,7 @@ int profile_gemm_bias_relu(int argc, char* argv[]) ...@@ -40,7 +40,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n"); printf("arg14: split k into mulitiple batch\n");
exit(1); return false;
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -60,12 +60,12 @@ int profile_gemm_bias_relu(int argc, char* argv[]) ...@@ -60,12 +60,12 @@ int profile_gemm_bias_relu(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_bias_relu_impl<ck::half_t, return ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -79,12 +79,12 @@ int profile_gemm_bias_relu(int argc, char* argv[]) ...@@ -79,12 +79,12 @@ int profile_gemm_bias_relu(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_bias_relu_impl<ck::half_t, return ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -98,12 +98,12 @@ int profile_gemm_bias_relu(int argc, char* argv[]) ...@@ -98,12 +98,12 @@ int profile_gemm_bias_relu(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_bias_relu_impl<ck::half_t, return ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -117,12 +117,12 @@ int profile_gemm_bias_relu(int argc, char* argv[]) ...@@ -117,12 +117,12 @@ int profile_gemm_bias_relu(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_bias_relu_impl<ck::half_t, return ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -136,8 +136,8 @@ int profile_gemm_bias_relu(int argc, char* argv[]) ...@@ -136,8 +136,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! this data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -6,27 +6,27 @@ ...@@ -6,27 +6,27 @@
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_bias_relu_add_impl.hpp" #include "profile_gemm_bias_relu_add_impl.hpp"
enum struct GemmMatrixLayout bool profile_gemm_bias_relu_add(int argc, char* argv[])
{ {
MK_KN_MN, // 0 enum struct GemmMatrixLayout
MK_NK_MN, // 1 {
KM_KN_MN, // 2 MK_KN_MN, // 0
KM_NK_MN, // 3 MK_NK_MN, // 1
MK_KN_NM, // 4 KM_KN_MN, // 2
MK_NK_NM, // 5 KM_NK_MN, // 3
KM_KN_NM, // 6 MK_KN_NM, // 4
KM_NK_NM, // 7 MK_NK_NM, // 5
}; KM_KN_NM, // 6
KM_NK_NM, // 7
};
enum struct GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
}; };
int profile_gemm_bias_relu_add(int argc, char* argv[]) if(argc != 15)
{
if(!(argc == 15 || argc == 16))
{ {
printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU+Add)\n"); printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU+Add)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n"); printf("arg2: data type (0: fp32; 1: fp16)\n");
...@@ -39,8 +39,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) ...@@ -39,8 +39,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n"); printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n");
printf("arg15: split k into mulitiple batch\n"); return false;
exit(1);
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -61,12 +60,12 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) ...@@ -61,12 +60,12 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t, return ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -81,12 +80,12 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) ...@@ -81,12 +80,12 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t, return ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -101,12 +100,12 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) ...@@ -101,12 +100,12 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t, return ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -121,12 +120,12 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) ...@@ -121,12 +120,12 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t, return ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -141,8 +140,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) ...@@ -141,8 +140,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! this data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_reduce_impl.hpp" #include "profile_gemm_reduce_impl.hpp"
int profile_gemm_reduce(int argc, char* argv[]) // return true if test pass
bool profile_gemm_reduce(int argc, char* argv[])
{ {
enum struct GemmMatrixLayout enum struct GemmMatrixLayout
{ {
...@@ -22,7 +23,7 @@ int profile_gemm_reduce(int argc, char* argv[]) ...@@ -22,7 +23,7 @@ int profile_gemm_reduce(int argc, char* argv[])
F16_F16_F16_F32_F32, // 1 F16_F16_F16_F32_F32, // 1
}; };
if(!(argc == 14 || argc == 15)) if(argc != 14)
{ {
printf("arg1: tensor operation (gemm: GEMM+Reduce)\n"); printf("arg1: tensor operation (gemm: GEMM+Reduce)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n"); printf("arg2: data type (0: fp32; 1: fp16)\n");
...@@ -34,9 +35,7 @@ int profile_gemm_reduce(int argc, char* argv[]) ...@@ -34,9 +35,7 @@ int profile_gemm_reduce(int argc, char* argv[])
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg8: print tensor value (0: no; 1: yes)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); return false;
printf("arg14: split k into mulitiple batch\n");
exit(1);
} }
const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
...@@ -56,13 +55,13 @@ int profile_gemm_reduce(int argc, char* argv[]) ...@@ -56,13 +55,13 @@ int profile_gemm_reduce(int argc, char* argv[])
if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm_reduce_impl<ck::half_t, return ck::profiler::profile_gemm_reduce_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -77,13 +76,13 @@ int profile_gemm_reduce(int argc, char* argv[]) ...@@ -77,13 +76,13 @@ int profile_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::MK_NK_MN) layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_gemm_reduce_impl<ck::half_t, return ck::profiler::profile_gemm_reduce_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -98,13 +97,13 @@ int profile_gemm_reduce(int argc, char* argv[]) ...@@ -98,13 +97,13 @@ int profile_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_KN_MN) layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_gemm_reduce_impl<ck::half_t, return ck::profiler::profile_gemm_reduce_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -119,13 +118,13 @@ int profile_gemm_reduce(int argc, char* argv[]) ...@@ -119,13 +118,13 @@ int profile_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_NK_MN) layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_gemm_reduce_impl<ck::half_t, return ck::profiler::profile_gemm_reduce_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
float, float,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -139,8 +138,8 @@ int profile_gemm_reduce(int argc, char* argv[]) ...@@ -139,8 +138,8 @@ int profile_gemm_reduce(int argc, char* argv[])
} }
else else
{ {
throw std::runtime_error("wrong! this data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_gemm_splitk_impl.hpp"
// return true if test pass
bool profile_gemm_splitk(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
if(argc != 15)
{
printf("arg1: tensor operation (gemm: GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
return false;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const int nrepeat = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int KBatch = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return true;
}
}
...@@ -6,26 +6,6 @@ ...@@ -6,26 +6,6 @@
#include <half.hpp> #include <half.hpp>
#include "profile_grouped_gemm_impl.hpp" #include "profile_grouped_gemm_impl.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
MK_KN_NM, // 4
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
std::vector<int> argToIntArray(char* input) std::vector<int> argToIntArray(char* input)
{ {
std::vector<int> out; std::vector<int> out;
...@@ -42,9 +22,25 @@ std::vector<int> argToIntArray(char* input) ...@@ -42,9 +22,25 @@ std::vector<int> argToIntArray(char* input)
return out; return out;
} }
int profile_grouped_gemm(int argc, char* argv[]) bool profile_grouped_gemm(int argc, char* argv[])
{ {
if(!(argc == 14)) enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
if(argc != 14)
{ {
printf("arg1: tensor operation (grouped_gemm: Grouped GEMM)\n"); printf("arg1: tensor operation (grouped_gemm: Grouped GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
...@@ -58,7 +54,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -58,7 +54,7 @@ int profile_grouped_gemm(int argc, char* argv[])
printf("arg7: run kernel # of times (>1)\n"); printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"); "64,64 64,64 128,128)\n");
exit(1); return false;
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -78,80 +74,84 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -78,80 +74,84 @@ int profile_grouped_gemm(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_impl<ck::half_t, return ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification, ck::tensor_layout::gemm::RowMajor>(
init_method, do_verification,
do_log, init_method,
nrepeat, do_log,
Ms, nrepeat,
Ns, Ms,
Ks, Ns,
StrideAs, Ks,
StrideBs, StrideAs,
StrideCs); StrideBs,
StrideCs);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_impl<ck::half_t, return ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification, ck::tensor_layout::gemm::RowMajor>(
init_method, do_verification,
do_log, init_method,
nrepeat, do_log,
Ms, nrepeat,
Ns, Ms,
Ks, Ns,
StrideAs, Ks,
StrideBs, StrideAs,
StrideCs); StrideBs,
StrideCs);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_impl<ck::half_t, return ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification, ck::tensor_layout::gemm::RowMajor>(
init_method, do_verification,
do_log, init_method,
nrepeat, do_log,
Ms, nrepeat,
Ns, Ms,
Ks, Ns,
StrideAs, Ks,
StrideBs, StrideAs,
StrideCs); StrideBs,
StrideCs);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_impl<ck::half_t, return ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t, ck::half_t,
ck::half_t, ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification, ck::tensor_layout::gemm::RowMajor>(
init_method, do_verification,
do_log, init_method,
nrepeat, do_log,
Ms, nrepeat,
Ns, Ms,
Ks, Ns,
StrideAs, Ks,
StrideBs, StrideAs,
StrideCs); StrideBs,
StrideCs);
} }
else else
{ {
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); std::cout << "this data_type & layout is not implemented" << std::endl;
}
return 1; return true;
}
} }
...@@ -320,7 +320,7 @@ class AppArgs ...@@ -320,7 +320,7 @@ class AppArgs
}; // end of class AppArgs }; // end of class AppArgs
int profile_reduce(int argc, char* argv[]) bool profile_reduce(int argc, char* argv[])
{ {
using namespace ck::profiler; using namespace ck::profiler;
...@@ -499,5 +499,5 @@ int profile_reduce(int argc, char* argv[]) ...@@ -499,5 +499,5 @@ int profile_reduce(int argc, char* argv[])
throw std::runtime_error("Invalid compType assignment!"); throw std::runtime_error("Invalid compType assignment!");
}; };
return (0); return true;
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment